From b6d00f897cd8f52165596debbcf274c21f4aa3b4 Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Thu, 9 Apr 2026 15:49:29 -0700 Subject: [PATCH 01/19] Add support for replicate op in distributed training - Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion - Add perform_shard_expansion_for_replicate and _bwd for shard expansion - Add build_replicate_invocation in make_dynamic_open_dataflow_graph - Add is_replicate_attrs helper and guard replicate in copy_insertion - Add ReplicateAttrs to TrainingOperationAttrs - Add SumReductionFloat/Double for backward replicate reduce operation - Add issue_replicate_bwd in spawn_dynamic_node_invocation - Fix per_device_op_state init race condition with direct write - Fix .value() calls on optional per_device_op_state across op impls - Update issue_copy to support optional reduction op - Add testcase for replicate op --- .../include/realm-execution/sum_reduction.h | 99 ++++ .../realm-execution/tasks/realm_reduction.h | 96 ++++ .../src/realm-execution/test_op_replicate.cc | 450 ++++++++++++++++++ 3 files changed, 645 insertions(+) create mode 100644 lib/realm-execution/include/realm-execution/sum_reduction.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_reduction.h create mode 100644 lib/realm-execution/test/src/realm-execution/test_op_replicate.cc diff --git a/lib/realm-execution/include/realm-execution/sum_reduction.h b/lib/realm-execution/include/realm-execution/sum_reduction.h new file mode 100644 index 0000000000..b845b5b7f2 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/sum_reduction.h @@ -0,0 +1,99 @@ +#pragma once +#include +#include "op-attrs/datatype.dtg.h" + +namespace FlexFlow { + +// Sum reduction for float +struct SumReductionFloat { + using LHS = float; + using RHS = float; + static const RHS identity; + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + // atomic add for non-exclusive + __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); + // proper float atomic add — use union trick + union { float f; int i; } old_val, new_val; + do { + old_val.f = lhs; + new_val.f = old_val.f + rhs; + } while (!__sync_bool_compare_and_swap( + (int*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { float f; int i; } old_val, new_val; + do { + old_val.f = rhs1; + new_val.f = old_val.f + rhs2; + } while (!__sync_bool_compare_and_swap( + (int*)&rhs1, old_val.i, new_val.i)); + } + } +}; + +const SumReductionFloat::RHS SumReductionFloat::identity = 0.0f; + +// Sum reduction for double +struct SumReductionDouble { + using LHS = double; + using RHS = double; + static const RHS identity; + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = lhs; + new_val.d = old_val.d + rhs; + } while (!__sync_bool_compare_and_swap( + (long long*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = rhs1; + new_val.d = old_val.d + rhs2; + } while (!__sync_bool_compare_and_swap( + (long long*)&rhs1, old_val.i, new_val.i)); + } + } +}; + +const SumReductionDouble::RHS SumReductionDouble::identity = 0.0; + +// Reduction op IDs — must not conflict with other registered redops +enum SumReductionOpIDs { + REDOP_SUM_FLOAT = 1, + REDOP_SUM_DOUBLE = 2, +}; + +inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { + switch (dtype) { + case DataType::FLOAT: return REDOP_SUM_FLOAT; + case DataType::DOUBLE: return REDOP_SUM_DOUBLE; + default: + PANIC("no sum reduction registered for datatype {}", dtype); + } +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h new file mode 100644 index 0000000000..d1d6e1d880 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h @@ -0,0 +1,96 @@ +#pragma once +#include +#include "op-attrs/datatype.dtg.h" + +namespace FlexFlow { + +// Sum reduction for float +struct SumReductionFloat { + using LHS = float; + using RHS = float; + static constexpr RHS identity = 0.0f; // ← inside struct, constexpr + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + // atomic add for non-exclusive + __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); + // proper float atomic add — use union trick + union { float f; int i; } old_val, new_val; + do { + old_val.f = lhs; + new_val.f = old_val.f + rhs; + } while (!__sync_bool_compare_and_swap( + (int*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { float f; int i; } old_val, new_val; + do { + old_val.f = rhs1; + new_val.f = old_val.f + rhs2; + } while (!__sync_bool_compare_and_swap( + (int*)&rhs1, old_val.i, new_val.i)); + } + } +}; + + +// Sum reduction for double +struct SumReductionDouble { + using LHS = double; + using RHS = double; + static constexpr RHS identity = 0.0; // ← inside struct, constexpr + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = lhs; + new_val.d = old_val.d + rhs; + } while (!__sync_bool_compare_and_swap( + (long long*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = rhs1; + new_val.d = old_val.d + rhs2; + } while (!__sync_bool_compare_and_swap( + (long long*)&rhs1, old_val.i, new_val.i)); + } + } +}; + +// Reduction op IDs — must not conflict with other registered redops +enum SumReductionOpIDs { + REDOP_SUM_FLOAT = 1, + REDOP_SUM_DOUBLE = 2, +}; + +inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { + switch (dtype) { + case DataType::FLOAT: return REDOP_SUM_FLOAT; + case DataType::DOUBLE: return REDOP_SUM_DOUBLE; + default: + PANIC("no sum reduction registered for datatype {}", dtype); + } +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc new file mode 100644 index 0000000000..d1fc941007 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -0,0 +1,450 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch, + Allocator &allocator) { + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = manager.start_controller([](RealmContext + &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /* sum_component */ 0_n, /* discard_copy_component */ 0_n, + /*shard_component*/ FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /* sum_component */ 0_n, /* discard_copy_component */ 1_n, + /*shard_component*/ FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + {{inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, + tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, + tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}}, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = + create_distributed_ff_handle(ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}}}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + }, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} +} // namespace test From 34056217cbb4a8067e582a792fa8af726c8d712e Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Thu, 9 Apr 2026 15:52:21 -0700 Subject: [PATCH 02/19] Add support for replicate op in distributed training - Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion - Add perform_shard_expansion_for_replicate and _bwd for shard expansion - Add build_replicate_invocation in make_dynamic_open_dataflow_graph - Add is_replicate_attrs helper and guard replicate in copy_insertion - Add ReplicateAttrs to TrainingOperationAttrs - Add SumReductionFloat/Double for backward replicate reduce operation - Add issue_replicate_bwd in spawn_dynamic_node_invocation - Fix per_device_op_state init race condition with direct write - Fix .value() calls on optional per_device_op_state across op impls - Update issue_copy to support optional reduction op - Add testcase for replicate op --- .../src/op-attrs/ops/element_unary.cc | 1 - .../test/src/op-attrs/ops/element_unary.cc | 8 - .../include/realm-execution/realm_context.h | 19 +- .../include/realm-execution/sum_reduction.h | 99 ---- .../realm-execution/tasks/realm_reduction.h | 49 +- ...uted_per_device_op_state_initialization.cc | 6 +- .../src/realm-execution/pcg_instance.cc | 54 +++ .../src/realm-execution/realm_context.cc | 9 +- .../impl/per_device_op_state_init_task.cc | 16 +- .../tasks/realm_task_registry.cc | 10 + .../src/realm-execution/test_op_replicate.cc | 444 +++++++++--------- .../training_operation_attrs.dtg.toml | 4 + .../task-spec/dynamic_graph/copy_insertion.cc | 47 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 127 +++++ .../task-spec/dynamic_graph/pass_expansion.cc | 43 ++ .../dynamic_graph/shard_expansion.cc | 125 ++++- .../src/task-spec/ops/impl/element_binary.cc | 8 +- .../src/task-spec/ops/impl/element_unary.cc | 8 +- 18 files changed, 713 insertions(+), 364 deletions(-) delete mode 100644 lib/realm-execution/include/realm-execution/sum_reduction.h diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 9d02923689..ca7e417814 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees( ElementUnaryAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { ASSERT(input_degrees.sum_degree.value == 1); - ASSERT(input_degrees.discard_copy_degree.value == 1); return input_degrees; } diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 672b160cbd..43b4be06d8 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -62,13 +62,5 @@ TEST_SUITE(FF_TEST_SUITE) { SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } - SUBCASE("discard copy degree > 1") { - positive_int degree = 2_p; - - CHECK_THROWS(get_output_shape( - attrs, - make_input( - SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p))); - } } } diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index ab89e916c0..eab42d0d79 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -63,15 +63,18 @@ struct RealmContext { int priority = 0); ///\} - /** \name Data movement */ + /** \name Data movement and reduction */ ///\{ - Realm::Event issue_copy(ParallelTensorShape const &src_shape, - Realm::RegionInstance src_inst, - ParallelTensorShape const &dst_shape, - Realm::RegionInstance dst_inst, - Realm::ProfilingRequestSet const &requests, - Realm::Event wait_on = Realm::Event::NO_EVENT, - int priority = 0); + Realm::Event + issue_copy(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0, + std::optional redop_id = std::nullopt, + bool exclusive = false); ///\} /** \name Instance management */ diff --git a/lib/realm-execution/include/realm-execution/sum_reduction.h b/lib/realm-execution/include/realm-execution/sum_reduction.h deleted file mode 100644 index b845b5b7f2..0000000000 --- a/lib/realm-execution/include/realm-execution/sum_reduction.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once -#include -#include "op-attrs/datatype.dtg.h" - -namespace FlexFlow { - -// Sum reduction for float -struct SumReductionFloat { - using LHS = float; - using RHS = float; - static const RHS identity; - - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - // atomic add for non-exclusive - __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); - // proper float atomic add — use union trick - union { float f; int i; } old_val, new_val; - do { - old_val.f = lhs; - new_val.f = old_val.f + rhs; - } while (!__sync_bool_compare_and_swap( - (int*)&lhs, old_val.i, new_val.i)); - } - } - - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - union { float f; int i; } old_val, new_val; - do { - old_val.f = rhs1; - new_val.f = old_val.f + rhs2; - } while (!__sync_bool_compare_and_swap( - (int*)&rhs1, old_val.i, new_val.i)); - } - } -}; - -const SumReductionFloat::RHS SumReductionFloat::identity = 0.0f; - -// Sum reduction for double -struct SumReductionDouble { - using LHS = double; - using RHS = double; - static const RHS identity; - - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - union { double d; long long i; } old_val, new_val; - do { - old_val.d = lhs; - new_val.d = old_val.d + rhs; - } while (!__sync_bool_compare_and_swap( - (long long*)&lhs, old_val.i, new_val.i)); - } - } - - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - union { double d; long long i; } old_val, new_val; - do { - old_val.d = rhs1; - new_val.d = old_val.d + rhs2; - } while (!__sync_bool_compare_and_swap( - (long long*)&rhs1, old_val.i, new_val.i)); - } - } -}; - -const SumReductionDouble::RHS SumReductionDouble::identity = 0.0; - -// Reduction op IDs — must not conflict with other registered redops -enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, - REDOP_SUM_DOUBLE = 2, -}; - -inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { - switch (dtype) { - case DataType::FLOAT: return REDOP_SUM_FLOAT; - case DataType::DOUBLE: return REDOP_SUM_DOUBLE; - default: - PANIC("no sum reduction registered for datatype {}", dtype); - } -} - -} // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h index d1d6e1d880..d9cf00441b 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h @@ -1,6 +1,6 @@ #pragma once -#include #include "op-attrs/datatype.dtg.h" +#include namespace FlexFlow { @@ -8,7 +8,7 @@ namespace FlexFlow { struct SumReductionFloat { using LHS = float; using RHS = float; - static constexpr RHS identity = 0.0f; // ← inside struct, constexpr + static constexpr RHS identity = 0.0f; // ← inside struct, constexpr template static void apply(LHS &lhs, RHS rhs) { @@ -16,14 +16,17 @@ struct SumReductionFloat { lhs += rhs; } else { // atomic add for non-exclusive - __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); + __sync_fetch_and_add((int *)&lhs, *(int *)&rhs); // proper float atomic add — use union trick - union { float f; int i; } old_val, new_val; + union { + float f; + int i; + } old_val, new_val; do { old_val.f = lhs; new_val.f = old_val.f + rhs; - } while (!__sync_bool_compare_and_swap( - (int*)&lhs, old_val.i, new_val.i)); + } while ( + !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i)); } } @@ -32,34 +35,39 @@ struct SumReductionFloat { if (EXCLUSIVE) { rhs1 += rhs2; } else { - union { float f; int i; } old_val, new_val; + union { + float f; + int i; + } old_val, new_val; do { old_val.f = rhs1; new_val.f = old_val.f + rhs2; - } while (!__sync_bool_compare_and_swap( - (int*)&rhs1, old_val.i, new_val.i)); + } while ( + !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i)); } } }; - // Sum reduction for double struct SumReductionDouble { using LHS = double; using RHS = double; - static constexpr RHS identity = 0.0; // ← inside struct, constexpr + static constexpr RHS identity = 0.0; // ← inside struct, constexpr template static void apply(LHS &lhs, RHS rhs) { if (EXCLUSIVE) { lhs += rhs; } else { - union { double d; long long i; } old_val, new_val; + union { + double d; + long long i; + } old_val, new_val; do { old_val.d = lhs; new_val.d = old_val.d + rhs; } while (!__sync_bool_compare_and_swap( - (long long*)&lhs, old_val.i, new_val.i)); + (long long *)&lhs, old_val.i, new_val.i)); } } @@ -68,26 +76,31 @@ struct SumReductionDouble { if (EXCLUSIVE) { rhs1 += rhs2; } else { - union { double d; long long i; } old_val, new_val; + union { + double d; + long long i; + } old_val, new_val; do { old_val.d = rhs1; new_val.d = old_val.d + rhs2; } while (!__sync_bool_compare_and_swap( - (long long*)&rhs1, old_val.i, new_val.i)); + (long long *)&rhs1, old_val.i, new_val.i)); } } }; // Reduction op IDs — must not conflict with other registered redops enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, + REDOP_SUM_FLOAT = 1, REDOP_SUM_DOUBLE = 2, }; inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { switch (dtype) { - case DataType::FLOAT: return REDOP_SUM_FLOAT; - case DataType::DOUBLE: return REDOP_SUM_DOUBLE; + case DataType::FLOAT: + return REDOP_SUM_FLOAT; + case DataType::DOUBLE: + return REDOP_SUM_DOUBLE; default: PANIC("no sum reduction registered for datatype {}", dtype); } diff --git a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc index 1d517a8fe4..e7d8647b12 100644 --- a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc @@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( std::unordered_map *> device_state_map; + std::vector completion_events; for (DynamicNodeInvocation const &invocation : dg.invocations) { Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); @@ -56,6 +57,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( precondition); if (completion_event.has_value()) { + completion_events.push_back(completion_event.value()); device_state_map.insert(std::pair{invocation, device_state_ptr}); } else { // Task doesn't require initialization, clean up and don't store result @@ -63,7 +65,9 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( } } - ctx.get_outstanding_events().wait(); + // wait for all init tasks — direct write to *result_ptr happens + // before each init task event fires so result is ready after this + Realm::Event::merge_events(completion_events).wait(); auto deref = [](DeviceSpecificPtr *const &p) { return *p; }; std::unordered_map> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 0ecd02143e..a0653c3c37 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -6,6 +6,7 @@ #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -215,6 +216,46 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; + // issue_replicate_bwd lambda + auto issue_replicate_bwd = [&]() { + std::optional output_grad_opt; + for (auto const &[slot, value] : invocation.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_opt = value; + } + } + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs input_grad = get_only(invocation.outputs).second; + Realm::RegionInstance dst_inst = + tensor_instance_backing.backing.at(input_grad).first; + + Realm::ReductionOpID redop_id = get_sum_reduction_op_id( + assert_unwrap(output_grad.parallel_tensor_shape).data_type); + + // chain reductions sequentially to avoid write races on dst + Realm::Event e = precondition; + for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) { + DynamicValueAttrs replica_key = output_grad; + replica_key.mapping = + bidict{{p, m}}; + replica_key.shard_coord = p; + + Realm::RegionInstance src_inst = + tensor_instance_backing.backing.at(replica_key).first; + + e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape), + src_inst, + assert_unwrap(input_grad.parallel_tensor_shape), + dst_inst, + Realm::ProfilingRequestSet{}, + e, + 0, + redop_id, + false); + } + return e; + }; + TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); return op_attrs.visit(overload{ @@ -222,11 +263,24 @@ static Realm::Event spawn_dynamic_node_invocation( return pcg_op_attrs.visit(overload{ [&](InputAttrs const &) { return Realm::Event::NO_EVENT; }, [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; }, + [&](ReplicateAttrs const &) { + // this should never be reached since replicate + // goes through TrainingOperationAttrs::ReplicateAttrs + PANIC("unexpected replicate in PCGOperatorAttrs path"); + return Realm::Event::NO_EVENT; + }, [&](auto const &) { return spawn_task(); }, }); }, [&](LossAttrs const &) { return spawn_task(); }, [&](CopyAttrs const &) { return issue_copy(); }, + [&](ReplicateAttrs const &) { + if (invocation.node_attrs.task_type.has_value() && + invocation.node_attrs.task_type.value() == DynamicTaskType::BWD) { + return issue_replicate_bwd(); + } + return issue_copy(); + }, }); } diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 790c1bd613..a4669bf43e 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -161,7 +161,9 @@ Realm::Event Realm::RegionInstance dst_inst, Realm::ProfilingRequestSet const &requests, Realm::Event wait_on, - int priority) { + int priority, + std::optional redop_id, + bool exclusive) { TensorShape src_piece_shape = get_piece_shape(src_shape); TensorShape dst_piece_shape = get_piece_shape(dst_shape); ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match @@ -183,6 +185,11 @@ Realm::Event size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), /*subfield_offset=*/0); + // set reduction op on dst field if provided + if (redop_id.has_value()) { + dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive); + } + Realm::Event result; switch (src_piece_shape.dims.ff_ordered.num_dims()) { #if REALM_MAX_DIM >= 1 diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc index 753fccf74b..0ea51810e4 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc @@ -66,11 +66,17 @@ void per_device_op_state_init_task_body(void const *args, result_state, ctx.get_current_device_idx())}; DeviceSpecificPtr result_device_specific{ ctx.get_current_device_idx(), result_state_ptr}; - spawn_per_device_op_state_init_return_task(ctx, - task_args.origin_proc, - result_device_specific, - task_args.origin_result_ptr, - Realm::Event::NO_EVENT); + + // replace spawn_per_device_op_state_init_return_task with: + // NOTE: SM/TODO: direct write assumes single-node shared address space + // For multi-node, replace with UserEvent trigger pattern + *task_args.origin_result_ptr = result_device_specific; + + // spawn_per_device_op_state_init_return_task(ctx, + // task_args.origin_proc, + // result_device_specific, + // task_args.origin_result_ptr, + // Realm::Event::NO_EVENT); } std::optional spawn_per_device_op_state_init_task( diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..acafdf59fd 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -5,6 +5,7 @@ #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_task.h" +#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/exception.h" @@ -30,9 +31,18 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::ProfilingRequestSet()); } +static void register_reductions() { + // register sum reduction ops + Realm::Runtime rt = Realm::Runtime::get_runtime(); + rt.register_reduction(REDOP_SUM_FLOAT); + rt.register_reduction(REDOP_SUM_DOUBLE); + // register_reduction is synchronous — no event returned +} + Realm::Event register_all_tasks() { std::vector pending_registrations; + register_reductions(); std::vector init_task_ids = { // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index d1fc941007..632f08d239 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -56,194 +56,207 @@ TEST_SUITE(FF_TEST_SUITE) { char **fake_argv = fake_args.data(); RealmManager manager = RealmManager{&fake_argc, &fake_argv}; - ControllerTaskResult result = manager.start_controller([](RealmContext - &ctx) { - Allocator allocator = ctx.get_current_device_allocator(); - - positive_int batch_size = 10_p; - positive_int data_dim = 16_p; - positive_int hidden_dim = 32_p; - positive_int output_dim = 1_p; - - // 10,2 - TensorShape output_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - // 10,2 - TensorShape label_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - GenericTensorAccessorW label_tensor = - allocator.allocate_tensor(label_tensor_shape); - - // construct computation graph - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - // input tensor - // 10, 16 - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; - - // parallel layer -> input tensor - ParallelLayerAddedResult inputs_layer = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input = - require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> input tensor 2 - ParallelLayerAddedResult inputs_layer_2 = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input_2 = - require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); - - // binary ADD attribute - ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - false, - false, - }; - - // parallel layer -> perform add - ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(add_attrs), - { - { - TensorSlotName::LHS_INPUT, - t_input, - }, + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), { - TensorSlotName::RHS_INPUT, - t_input_2, + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, }, - }, - {/* weight */}); - - parallel_tensor_guid_t t_add_1 = - require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform replicate - const positive_int replicate_degree = 2_p; - ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); - ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(repl_attrs), - { + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), { - TensorSlotName::INPUT, - t_add_1, + { + TensorSlotName::INPUT, + t_add_1, + }, }, - }, - /*weight=*/{}); - // output of replicate layer - parallel_tensor_guid_t t_repl_1 = - require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform RelU - ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), - /*inputs=*/ - { + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ { - TensorSlotName::INPUT, - t_repl_1, + { + TensorSlotName::INPUT, + t_repl_1, + }, }, - }, - /*weights=*/{}); - // output of relu layer - parallel_tensor_guid_t t_relu_1 = - require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); - - // machine - MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; - MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; - - ParallelTensorSpaceCoordinate tensor_coord0{ - /* sum_component */ 0_n, /* discard_copy_component */ 0_n, - /*shard_component*/ FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{ - /* sum_component */ 0_n, /* discard_copy_component */ 1_n, - /*shard_component*/ FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, - {{inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, - tensor_coord0}}}}}}}, - {inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, - tensor_coord0}}}}}}}, - {add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}, - {relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}}, - }; - - MappedOperatorTaskGroup loss_mapping{ - {{cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; - - // instantiate computation graph - LossAttrs loss_attrs = LossAttrs{ - NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; - OptimizerAttrs optimizer_attrs = - OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, - /*momentum=*/0.9, - /*nesterov=*/false, - /*weight_decay=*/0.001}}; - - std::unordered_map - input_tensors; - - DistributedFfHandle device_handle = - create_distributed_ff_handle(ctx, - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); - PCGInstance pcg_instance = create_pcg_instance( - /*ctx=*/ctx, - /*mpcg=*/mpcg, - /*optimizer=*/optimizer_attrs, - /*loss=*/std::nullopt, - /*input_tensors=*/input_tensors, - /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); - - // begin training loop - int num_epochs = 1; - for (int i = 0; i < num_epochs; i++) { - perform_all_passes_for_pcg_instance( - /*instance=*/pcg_instance, - /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); - } - }); + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /* sum_component */ 0_n, + /* discard_copy_component */ 0_n, + /*shard_component*/ FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /* sum_component */ 0_n, + /* discard_copy_component */ 1_n, + /*shard_component*/ FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + {{inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}}, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); result.wait(); } } @@ -307,7 +320,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // parallel layer -> perform add ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(add_attrs), + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), { { TensorSlotName::LHS_INPUT, @@ -327,7 +341,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { const positive_int replicate_degree = 2_p; ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(repl_attrs), + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), { { TensorSlotName::INPUT, @@ -341,7 +356,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // parallel layer -> perform RelU ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), /*inputs=*/ { { @@ -357,8 +373,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // machine MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; - ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; MappedParallelComputationGraph mpcg{ pcg, { @@ -374,38 +390,44 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, {add_operator_1.parallel_layer, MappedOperatorTaskGroup{ - {{gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}}}}, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}}}}, {relu_operator_1.parallel_layer, MappedOperatorTaskGroup{{ - {gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, }}}, }, }; MappedOperatorTaskGroup loss_mapping{ - {{gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 8f8f6467c8..2bd0714512 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -25,3 +25,7 @@ key = "loss" [[values]] type = "::FlexFlow::CopyAttrs" key = "copy" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate" diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 4c1b9d4609..7a28e254aa 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -25,15 +25,43 @@ bool node_is_copy(DynamicNodeAttrs const &n) { return n.op_attrs.has_value() && n.op_attrs.value().is_copy(); } +static bool is_replicate_invocation(DynamicNodeInvocation const &i) { + if (!i.node_attrs.op_attrs.has_value()) { + return false; + } + TrainingOperationAttrs const &op_attrs = i.node_attrs.op_attrs.value(); + if (op_attrs.is_replicate()) { + return true; + } + return false; +} + bool value_is_mapped(DynamicValueAttrs const &n) { return n.mapping.has_value(); } bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &g) { auto slot_is_mapped = [](DynamicTensorSlot const &) -> bool { return false; }; - - return no_part_of_dynamic_graph_satisfies( - g, node_is_copy, value_is_mapped, slot_is_mapped); + // check all non-replicate invocations + for (DynamicNodeInvocation const &i : g.invocations) { + if (is_replicate_invocation(i)) { + continue; // replicate tensors have mapping set by design + } + if (node_is_copy(i.node_attrs)) { + return false; + } + for (auto const &[slot, value] : i.inputs) { + if (value_is_mapped(value)) { + return false; + } + } + for (auto const &[slot, value] : i.outputs) { + if (value_is_mapped(value)) { + return false; + } + } + } + return true; } bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) { @@ -85,6 +113,11 @@ std::unordered_set perform_copy_insertion_for_invocation( std::unordered_map const &unmapped_value_to_mapped_source_value) { + // replicate nodes have no MappedOperatorTaskGroup — + // pass through unchanged, no copies needed + if (is_replicate_invocation(i)) { + return {i}; + } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); auto map_tensor = [&](DynamicTensorSlot const &slot, @@ -157,6 +190,14 @@ DynamicOpenDataflowGraph std::unordered_map unmapped_value_to_mapped_source_value; for (DynamicNodeInvocation const &i : g.invocations) { + // replicate nodes have no MappedOperatorTaskGroup — + // output mapping already fully set, maps to itself + if (is_replicate_invocation(i)) { + for (auto const &[slot, value] : i.outputs) { + unmapped_value_to_mapped_source_value.insert(std::pair{value, value}); + } + continue; + } for (auto const &[slot, value] : i.outputs) { unmapped_value_to_mapped_source_value.insert( std::pair{value, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 246f9a3242..3d48a0dc2b 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -7,11 +7,129 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" #include #include #include namespace FlexFlow { +static bidict + get_input_mapping_for_replicate( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &replicate_layer) { + + auto [input_slot_name, input_tensor_guid] = + get_only(get_incoming_tensors(mpcg.pcg, replicate_layer)); + + // find the layer that produces this tensor + for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { + for (auto const &[slot_name, t] : get_outgoing_tensors(mpcg.pcg, layer)) { + if (t == input_tensor_guid) { + MappedOperatorTaskGroup producer_mapping = mpcg.mapped_tasks.at(layer); + return get_tensor_bindings_for_slot_name(producer_mapping, slot_name); + } + } + } + + PANIC("could not find producer of replicate layer input tensor"); +} + +static std::unordered_map + get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &tensor) { + std::unordered_map result; + for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { + for (auto const &[slot_name, t] : get_incoming_tensors(mpcg.pcg, layer)) { + if (t == tensor) { + result.insert({layer, slot_name}); + } + } + } + return result; +} + +static bidict + build_replicated_output_mapping( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &replicate_layer) { + + auto [output_slot_name, output_tensor_guid] = + get_only(get_outgoing_tensors(mpcg.pcg, replicate_layer)); + + auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); + ASSERT(!consumers.empty()); + + // union all consumer bindings — each consumer shard maps to a distinct + // (discard_copy, machine) pair since replicas are always on different machines + bidict result; + for (auto const &[consumer_layer, slot_name] : consumers) { + MappedOperatorTaskGroup consumer_mapping = + mpcg.mapped_tasks.at(consumer_layer); + bidict binding = + get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); + for (auto const &[p, m] : binding) { + result.equate(p, m); + } + } + return result; +} + +static DynamicNodeInvocation + build_replicate_invocation(parallel_layer_guid_t const &layer, + ParallelLayerAttrs const &attrs, + MappedParallelComputationGraph const &mpcg) { + auto [input_slot_name, input_tensor_guid] = + get_only(get_incoming_tensors(mpcg.pcg, layer)); + auto incoming = get_incoming_tensors(mpcg.pcg, layer); + ASSERT(!incoming.empty(), + "replicate layer has no incoming tensors — " + "check PCG edge construction in test"); + + ParallelTensorAttrs input_attrs = + get_parallel_tensor_attrs(mpcg.pcg, input_tensor_guid); + bidict input_mapping = + get_input_mapping_for_replicate(mpcg, layer); + + DynamicValueAttrs input_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, + /*parallel_tensor_shape=*/input_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/get_input_mapping_for_replicate(mpcg, layer), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + + auto [output_slot_name, output_tensor_guid] = + get_only(get_outgoing_tensors(mpcg.pcg, layer)); + ParallelTensorAttrs output_attrs = + get_parallel_tensor_attrs(mpcg.pcg, output_tensor_guid); + + DynamicValueAttrs output_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, + /*parallel_tensor_shape=*/output_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/build_replicated_output_mapping(mpcg, layer), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + DynamicNodeAttrs node_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs.get()}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + DynamicNodeInvocation invocation_node{ + /*inputs=*/{ + {DynamicTensorSlot{input_slot_name, std::nullopt}, input_value}}, + /*node_attrs=*/node_attrs, + /*outputs=*/ + {{DynamicTensorSlot{output_slot_name, std::nullopt}, output_value}}, + }; + return invocation_node; +} DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { @@ -19,6 +137,15 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { + + if (attrs.op_attrs.has()) { + // build replicate invocation + DynamicNodeInvocation repl_inv = + build_replicate_invocation(layer, attrs, mpcg); + result.invocations.emplace(repl_inv); + continue; + } + DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 0cee06368f..aed5f2c4c3 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -4,6 +4,7 @@ #include "utils/containers/are_all_same.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" +#include "utils/containers/get_only.h" namespace FlexFlow { @@ -109,6 +110,44 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( transform(invocation.inputs, to_grad), }; } +static std::unordered_set + perform_pass_expansion_for_replicate( + DynamicNodeInvocation const &invocation) { + + auto const &[input_slot, input] = get_only(invocation.inputs); + auto const &[output_slot, output] = get_only(invocation.outputs); + + // forward: INPUT/FWD → OUTPUT/FWD (copy to replicas) + DynamicNodeInvocation fwd{ + /*inputs=*/{{pass_expand_slot(input_slot, FwbTensorType::FORWARD), + pass_expand_value(input, FwbTensorType::FORWARD)}}, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::FWD), + /*outputs=*/ + {{pass_expand_slot(output_slot, FwbTensorType::FORWARD), + pass_expand_value(output, FwbTensorType::FORWARD)}}, + }; + + // backward: OUTPUT/FWD + OUTPUT/GRAD → INPUT/GRAD (reduce gradients) + // The backward node needs the mapping from the output (replicated) + // so it knows which replicas to reduce from + DynamicNodeAttrs bwd_node_attrs = invocation.node_attrs; + bwd_node_attrs.task_type = DynamicTaskType::BWD; + + DynamicNodeInvocation bwd{ + /*inputs=*/{ + {pass_expand_slot(output_slot, FwbTensorType::FORWARD), + pass_expand_value(output, FwbTensorType::FORWARD)}, + {pass_expand_slot(output_slot, FwbTensorType::GRADIENT), + pass_expand_value(output, FwbTensorType::GRADIENT)}, + }, + /*node_attrs=*/bwd_node_attrs, + /*outputs=*/ + {{pass_expand_slot(input_slot, FwbTensorType::GRADIENT), + pass_expand_value(input, FwbTensorType::GRADIENT)}}, + }; + return {fwd, bwd}; +} DynamicOpenDataflowGraph perform_pass_expansion(DynamicOpenDataflowGraph const &g) { @@ -117,6 +156,10 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { + if (invocation.node_attrs.op_attrs.has_value() && + invocation.node_attrs.op_attrs.value().is_replicate()) { + return perform_pass_expansion_for_replicate(invocation); + } if (invocation.inputs.empty()) { return std::unordered_set{ perform_fwd_pass_expansion_for_invocation(invocation), diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index fb6efb96d0..f30a4d8470 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -39,7 +39,6 @@ bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { value_is_shard_expanded, slot_is_shard_expanded); } - static bidict restrict_tensor_mapping_keys_to_coord( bidict const @@ -85,6 +84,114 @@ static DynamicNodeInvocation shard_invocation_for_binding( }; } +static std::unordered_set + perform_shard_expansion_for_replicate(DynamicNodeInvocation const &i) { + auto const &[input_slot, input] = get_only(i.inputs); + auto const &[output_slot, output] = get_only(i.outputs); + + bidict input_mapping = + assert_unwrap(input.mapping); + bidict output_mapping = + assert_unwrap(output.mapping); + + return transform(output_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) { + ParallelTensorSpaceCoordinate input_p{ + /*sum_component=*/p.sum_component, + /*discard_copy_component=*/nonnegative_int{0}, + /*shard_components=*/p.shard_components, + }; + return shard_invocation_for_binding( + i, + output_mapping.at_l(p), + OperatorAtomicTaskShardBinding{{ + {input_slot.slot_name, input_p}, + {output_slot.slot_name, p}, + }}); + }); +} + +static std::unordered_set + perform_shard_expansion_for_replicate_bwd(DynamicNodeInvocation const &i) { + + std::optional output_grad_opt; + std::optional output_fwd_opt; + std::optional output_grad_slot_opt; + std::optional output_fwd_slot_opt; + + for (auto const &[slot, value] : i.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_slot_opt = slot; + output_grad_opt = value; + } else { + output_fwd_slot_opt = slot; + output_fwd_opt = value; + } + } + + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs output_fwd = assert_unwrap(output_fwd_opt); + DynamicTensorSlot output_grad_slot = assert_unwrap(output_grad_slot_opt); + DynamicTensorSlot output_fwd_slot = assert_unwrap(output_fwd_slot_opt); + auto const &[input_grad_slot, input_grad] = get_only(i.outputs); + + bidict + output_grad_mapping = assert_unwrap(output_grad.mapping); + bidict + input_grad_mapping = assert_unwrap(input_grad.mapping); + + std::unordered_map, + std::unordered_set> + by_shard; + for (auto const &p : output_grad_mapping.left_values()) { + by_shard[p.shard_components].insert(p); + } + + std::unordered_set result; + for (auto const &[shard_components, replica_coords] : by_shard) { + ParallelTensorSpaceCoordinate src_p{ + nonnegative_int{0}, nonnegative_int{0}, shard_components}; + MachineSpaceCoordinate src_machine = input_grad_mapping.at_l(src_p); + + bidict + replica_mapping; + for (auto const &p : replica_coords) { + replica_mapping.equate(p, output_grad_mapping.at_l(p)); + } + + DynamicValueAttrs sharded_output_grad = output_grad; + sharded_output_grad.mapping = replica_mapping; + sharded_output_grad.shard_coord = src_p; + + DynamicValueAttrs sharded_output_fwd = output_fwd; + sharded_output_fwd.mapping = replica_mapping; + sharded_output_fwd.shard_coord = src_p; + + DynamicValueAttrs sharded_input_grad = input_grad; + sharded_input_grad.mapping = + bidict{ + {src_p, src_machine}}; + sharded_input_grad.shard_coord = src_p; + + DynamicNodeAttrs sharded_node = i.node_attrs; + sharded_node.device_coord = src_machine; + + result.insert(DynamicNodeInvocation{ + /*inputs=*/{ + {output_fwd_slot, sharded_output_fwd}, + {output_grad_slot, sharded_output_grad}, + }, + /*node_attrs=*/sharded_node, + /*outputs=*/ + { + {input_grad_slot, sharded_input_grad}, + }, + }); + } + return result; +} + + static std::unordered_set perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); @@ -121,6 +228,22 @@ std::unordered_set return perform_shard_expansion_for_copy(i); } + // forward replicate + if (i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().is_replicate() && + i.node_attrs.task_type.has_value() && + i.node_attrs.task_type.value() == DynamicTaskType::FWD) { + return perform_shard_expansion_for_replicate(i); + } + + // backward replicate + if (i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().is_replicate() && + i.node_attrs.task_type.has_value() && + i.node_attrs.task_type.value() == DynamicTaskType::BWD) { + return perform_shard_expansion_for_replicate_bwd(i); + } + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); std::unordered_set shard_machine_coords = diff --git a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc index 13465d7a5f..c8460af538 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc @@ -36,8 +36,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); @@ -62,8 +62,8 @@ static std::optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); diff --git a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc index d66ff9ab8d..9a092b90b8 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc @@ -35,8 +35,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(forward_kernel, profiling, @@ -62,8 +62,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(backward_kernel, profiling, From d033e22f77d08fc6b4d1151ef7d6bf7cc23281cb Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Tue, 14 Apr 2026 17:10:12 -0700 Subject: [PATCH 03/19] remove ReplicateAttr --- .../src/realm-execution/pcg_instance.cc | 17 ++++++----------- .../src/realm-execution/tasks/task_id_t.cc | 12 +++--------- .../training_operation_attrs.dtg.toml | 4 ---- .../task-spec/dynamic_graph/copy_insertion.cc | 13 +++++-------- ...namic_open_dataflow_graph_from_mapped_pcg.cc | 2 +- .../task-spec/dynamic_graph/pass_expansion.cc | 10 +++++++--- .../task-spec/dynamic_graph/shard_expansion.cc | 16 +++++++++------- 7 files changed, 31 insertions(+), 43 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index a0653c3c37..17c62fe70c 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -264,23 +264,18 @@ static Realm::Event spawn_dynamic_node_invocation( [&](InputAttrs const &) { return Realm::Event::NO_EVENT; }, [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; }, [&](ReplicateAttrs const &) { - // this should never be reached since replicate - // goes through TrainingOperationAttrs::ReplicateAttrs - PANIC("unexpected replicate in PCGOperatorAttrs path"); - return Realm::Event::NO_EVENT; + if (invocation.node_attrs.task_type.has_value() && + invocation.node_attrs.task_type.value() == + DynamicTaskType::BWD) { + return issue_replicate_bwd(); + } + return issue_copy(); // forward }, [&](auto const &) { return spawn_task(); }, }); }, [&](LossAttrs const &) { return spawn_task(); }, [&](CopyAttrs const &) { return issue_copy(); }, - [&](ReplicateAttrs const &) { - if (invocation.node_attrs.task_type.has_value() && - invocation.node_attrs.task_type.value() == DynamicTaskType::BWD) { - return issue_replicate_bwd(); - } - return issue_copy(); - }, }); } diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index dd4b0a66ca..0bdc2ca6b5 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -64,9 +64,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_INIT_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_INIT_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return std::nullopt; }, [](ReverseAttrs const &) { return std::nullopt; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, @@ -115,9 +113,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_FWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_FWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_FWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, @@ -166,9 +162,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_BWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_BWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 2bd0714512..8f8f6467c8 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -25,7 +25,3 @@ key = "loss" [[values]] type = "::FlexFlow::CopyAttrs" key = "copy" - -[[values]] -type = "::FlexFlow::ReplicateAttrs" -key = "replicate" diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 7a28e254aa..ef41042a51 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -26,14 +26,11 @@ bool node_is_copy(DynamicNodeAttrs const &n) { } static bool is_replicate_invocation(DynamicNodeInvocation const &i) { - if (!i.node_attrs.op_attrs.has_value()) { - return false; - } - TrainingOperationAttrs const &op_attrs = i.node_attrs.op_attrs.value(); - if (op_attrs.is_replicate()) { - return true; - } - return false; + return i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().has() && + i.node_attrs.op_attrs.value() + .get() + .has(); } bool value_is_mapped(DynamicValueAttrs const &n) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 3d48a0dc2b..a4ef156db9 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -116,7 +116,7 @@ static DynamicNodeInvocation /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs.get()}, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index aed5f2c4c3..faa1e186c3 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -2,9 +2,9 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "utils/containers/are_all_same.h" +#include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" -#include "utils/containers/get_only.h" namespace FlexFlow { @@ -30,6 +30,11 @@ bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); } +static bool is_replicate_attrs(DynamicNodeAttrs const &n) { + return n.op_attrs.has_value() && n.op_attrs.value().has() && + n.op_attrs.value().get().has(); +} + DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, FwbTensorType tensor_type) { ASSERT(!slot_is_pass_expanded(s)); @@ -156,8 +161,7 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { - if (invocation.node_attrs.op_attrs.has_value() && - invocation.node_attrs.op_attrs.value().is_replicate()) { + if (is_replicate_attrs(invocation.node_attrs)) { return perform_pass_expansion_for_replicate(invocation); } if (invocation.inputs.empty()) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index f30a4d8470..d3365ae44c 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -191,7 +191,6 @@ static std::unordered_set return result; } - static std::unordered_set perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); @@ -228,18 +227,21 @@ std::unordered_set return perform_shard_expansion_for_copy(i); } + bool const is_replicate = + i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().has() && + i.node_attrs.op_attrs.value() + .get() + .has(); + // forward replicate - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_replicate() && - i.node_attrs.task_type.has_value() && + if (is_replicate && i.node_attrs.task_type.has_value() && i.node_attrs.task_type.value() == DynamicTaskType::FWD) { return perform_shard_expansion_for_replicate(i); } // backward replicate - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_replicate() && - i.node_attrs.task_type.has_value() && + if (is_replicate && i.node_attrs.task_type.has_value() && i.node_attrs.task_type.value() == DynamicTaskType::BWD) { return perform_shard_expansion_for_replicate_bwd(i); } From 6cd706091420f4e9c776d75dc3464bbf040f5385 Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Wed, 15 Apr 2026 16:15:21 -0700 Subject: [PATCH 04/19] Add comments to realm reductions, Use existing graph methods --- .../realm-execution/tasks/realm_reduction.h | 69 +++++++++++++++---- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 44 ++++++------ 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h index d9cf00441b..512e344824 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h @@ -1,23 +1,33 @@ -#pragma once +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H #include "op-attrs/datatype.dtg.h" #include namespace FlexFlow { -// Sum reduction for float +/** + * \brief Realm Sum Reduction for Float + * \see https://legion.stanford.edu/tutorial/realm/reductions.html + */ struct SumReductionFloat { using LHS = float; using RHS = float; - static constexpr RHS identity = 0.0f; // ← inside struct, constexpr + /** \brief Identity element for addition (0.0) */ + static constexpr RHS identity = 0.0f; + + /** + * \brief Apply reduction: lhs += rhs + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param lhs Left-hand side accumulator (modified in place) + * \param rhs Value to add + */ template static void apply(LHS &lhs, RHS rhs) { if (EXCLUSIVE) { lhs += rhs; } else { - // atomic add for non-exclusive - __sync_fetch_and_add((int *)&lhs, *(int *)&rhs); - // proper float atomic add — use union trick + // Atomic float add via CAS loop union { float f; int i; @@ -30,11 +40,18 @@ struct SumReductionFloat { } } + /** + * \brief Fold two RHS values: rhs1 += rhs2 + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param rhs1 Accumulator (modified in place) + * \param rhs2 Value to fold in + */ template static void fold(RHS &rhs1, RHS rhs2) { if (EXCLUSIVE) { rhs1 += rhs2; } else { + // Atomic float add via CAS loop union { float f; int i; @@ -48,17 +65,29 @@ struct SumReductionFloat { } }; -// Sum reduction for double +/** + * \brief Realm Sum Reduction for Double + * \see https://legion.stanford.edu/tutorial/realm/reductions.html + */ struct SumReductionDouble { using LHS = double; using RHS = double; - static constexpr RHS identity = 0.0; // ← inside struct, constexpr + /** \brief Identity element for addition (0.0) */ + static constexpr RHS identity = 0.0; + + /** + * \brief Apply reduction: lhs += rhs + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param lhs Left-hand side accumulator (modified in place) + * \param rhs Value to add + */ template static void apply(LHS &lhs, RHS rhs) { if (EXCLUSIVE) { lhs += rhs; } else { + // Atomic double add via CAS loop using long long reinterpretation union { double d; long long i; @@ -71,11 +100,18 @@ struct SumReductionDouble { } } + /** + * \brief Fold two RHS values: rhs1 += rhs2 + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param rhs1 Accumulator (modified in place) + * \param rhs2 Value to fold in + */ template static void fold(RHS &rhs1, RHS rhs2) { if (EXCLUSIVE) { rhs1 += rhs2; } else { + // Atomic double add via CAS loop using long long reinterpretation union { double d; long long i; @@ -89,12 +125,21 @@ struct SumReductionDouble { } }; -// Reduction op IDs — must not conflict with other registered redops +/** + * \brief Reduction op IDs for sum reductions + * \warning These IDs must not conflict with other registered reduction ops + */ enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, - REDOP_SUM_DOUBLE = 2, + REDOP_SUM_FLOAT = 1, ///< Sum reduction op ID for float + REDOP_SUM_DOUBLE = 2, ///< Sum reduction op ID for double }; +/** + * \brief Returns the Realm reduction op ID for a sum reduction over the given datatype + * \param dtype The datatype to look up + * \return The corresponding Realm::ReductionOpID + * \throws PANIC if no sum reduction is registered for the given datatype + */ inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { switch (dtype) { case DataType::FLOAT: @@ -105,5 +150,5 @@ inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { PANIC("no sum reduction registered for datatype {}", dtype); } } - } // namespace FlexFlow +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index a4ef156db9..9349341d4b 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -2,6 +2,7 @@ #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" @@ -18,31 +19,30 @@ static bidict MappedParallelComputationGraph const &mpcg, parallel_layer_guid_t const &replicate_layer) { - auto [input_slot_name, input_tensor_guid] = - get_only(get_incoming_tensors(mpcg.pcg, replicate_layer)); - - // find the layer that produces this tensor - for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { - for (auto const &[slot_name, t] : get_outgoing_tensors(mpcg.pcg, layer)) { - if (t == input_tensor_guid) { - MappedOperatorTaskGroup producer_mapping = mpcg.mapped_tasks.at(layer); - return get_tensor_bindings_for_slot_name(producer_mapping, slot_name); - } - } - } + // get_incoming_edges returns map + // replicate has exactly one input + auto [input_slot_name, input_edge] = + get_only(get_incoming_edges(mpcg.pcg, replicate_layer)); - PANIC("could not find producer of replicate layer input tensor"); + parallel_layer_guid_t producer_layer = get_src_layer(input_edge); + TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); + + return get_tensor_bindings_for_slot_name(mpcg.mapped_tasks.at(producer_layer), + producer_slot); } static std::unordered_map get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, parallel_tensor_guid_t const &tensor) { + parallel_layer_guid_t producer_layer = get_source_layer(mpcg.pcg, tensor); + std::unordered_map result; - for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { - for (auto const &[slot_name, t] : get_incoming_tensors(mpcg.pcg, layer)) { - if (t == tensor) { - result.insert({layer, slot_name}); - } + // get_outgoing_edges returns unordered_set + for (ParallelComputationGraphEdge const &edge : + get_outgoing_edges(mpcg.pcg, producer_layer)) { + if (get_parallel_tensor(edge) == tensor) { + result.insert( + std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); } } return result; @@ -76,7 +76,7 @@ static bidict static DynamicNodeInvocation build_replicate_invocation(parallel_layer_guid_t const &layer, - ParallelLayerAttrs const &attrs, + ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { auto [input_slot_name, input_tensor_guid] = get_only(get_incoming_tensors(mpcg.pcg, layer)); @@ -116,7 +116,7 @@ static DynamicNodeInvocation /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, + /*op_attrs=*/TrainingOperationAttrs{PCGOperatorAttrs{attrs}}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; @@ -140,8 +140,8 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( if (attrs.op_attrs.has()) { // build replicate invocation - DynamicNodeInvocation repl_inv = - build_replicate_invocation(layer, attrs, mpcg); + DynamicNodeInvocation repl_inv = build_replicate_invocation( + layer, attrs.op_attrs.get(), mpcg); result.invocations.emplace(repl_inv); continue; } From c50f3846e4f59920cce36792daeef22b2a70d9e0 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 15 May 2026 17:50:24 -0700 Subject: [PATCH 05/19] Minor PR fixes --- .../mapped_parallel_computation_graph.h | 23 ++ .../parallel_computation_graph.h | 5 + .../mapped_parallel_computation_graph.cc | 43 +++ .../parallel_computation_graph.cc | 15 + .../src/realm-execution/pcg_instance.cc | 38 ++- .../src/realm-execution/test_op_replicate.cc | 298 +++++++++++------- .../sub_parallel_computation_graph.h | 2 +- .../apply_substitution/apply_substitution.cc | 2 +- .../sub_parallel_computation_graph.cc | 2 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 32 +- .../get_kwarg_dataflow_value_uses.h | 33 ++ .../include/utils/many_to_one/many_to_one.h | 5 + .../include/utils/one_to_many/one_to_many.h | 5 + .../get_kwarg_dataflow_value_uses.cc | 14 + 14 files changed, 373 insertions(+), 144 deletions(-) create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 12c7921282..984a524c21 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -8,12 +8,35 @@ namespace FlexFlow { std::unordered_set mpcg_get_parallel_layers(MappedParallelComputationGraph const &); + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &, parallel_layer_guid_t); ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); +parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 0368be62bc..1b2d5a0b67 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -11,6 +11,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" namespace FlexFlow { @@ -53,6 +54,10 @@ std::unordered_map get_incoming_edges(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); + std::unordered_set get_initial_layers(ParallelComputationGraph const &); diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index f4fa946a66..571b89b6dd 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -8,6 +8,8 @@ #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" +#include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/many_to_one/many_to_one_from_map.h" namespace FlexFlow { @@ -46,6 +48,47 @@ ParallelComputationGraph }; } +parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) +{ + return get_source_layer(pcg_from_mpcg(mpcg), t); +} + +ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) +{ + return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); +} + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return get_incoming_edges(pcg_from_mpcg(mpcg), l); +} + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return get_outgoing_edges(pcg_from_mpcg(mpcg), l); +} + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return many_to_one_from_map(get_incoming_tensors(pcg_from_mpcg(mpcg), l)); +} + + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); +} + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index a548ceb65a..2c5197242d 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -36,6 +36,7 @@ #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" #include +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" namespace FlexFlow { @@ -206,6 +207,20 @@ std::unordered_map }); } +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) +{ + std::unordered_set> raw_uses = + get_kwarg_dataflow_value_uses(pcg.raw_graph, + t.raw_graph_output); + + return transform(raw_uses, [](KwargDataflowInput const &i) { + return parallel_tensor_use_t{i}; + }); +} + + std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 17c62fe70c..17a6a383e6 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -218,14 +218,17 @@ static Realm::Event spawn_dynamic_node_invocation( // issue_replicate_bwd lambda auto issue_replicate_bwd = [&]() { - std::optional output_grad_opt; - for (auto const &[slot, value] : invocation.inputs) { - if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { - output_grad_opt = value; - } - } - DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); - DynamicValueAttrs input_grad = get_only(invocation.outputs).second; + + DynamicValueAttrs output_grad = get_only( + values( + filter_keys( + invocation.inputs, + [](DynamicTensorSlot const &s) -> bool { + return s.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}; + }))); + + DynamicValueAttrs input_grad = get_only(values(invocation.outputs)); + Realm::RegionInstance dst_inst = tensor_instance_backing.backing.at(input_grad).first; @@ -243,15 +246,16 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance src_inst = tensor_instance_backing.backing.at(replica_key).first; - e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape), - src_inst, - assert_unwrap(input_grad.parallel_tensor_shape), - dst_inst, - Realm::ProfilingRequestSet{}, - e, - 0, - redop_id, - false); + e = ctx.issue_copy( + /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), + /*src_inst=*/src_inst, + /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), + /*dst_inst=*/dst_inst, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/e, + /*priority=*/0, + /*redop_id=*/redop_id, + /*exlusive=*/false); } return e; }; diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index 632f08d239..cae5ca1756 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -27,6 +27,7 @@ #include "test/utils/doctest/check_kv.h" #include "utils/containers/require_only_key.h" #include +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" namespace test { @@ -168,67 +169,116 @@ TEST_SUITE(FF_TEST_SUITE) { /* sum_component */ 0_n, /* discard_copy_component */ 1_n, /*shard_component*/ FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, - {{inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, + MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + { + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}, - {relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}}, - }; + }}, + }, + }, + }, + }, + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + }); MappedOperatorTaskGroup loss_mapping{ - {{cpu0, + { + { + cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}, + }, + }, + }; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; OptimizerAttrs optimizer_attrs = - OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, - /*momentum=*/0.9, - /*nesterov=*/false, - /*weight_decay=*/0.001}}; + OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; std::unordered_map input_tensors; @@ -375,68 +425,102 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, + MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ { - {inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}}}}, - {relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}, + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}, + }, + }}, }, - }; - - MappedOperatorTaskGroup loss_mapping{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}, + }}, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }}, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }}, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }}, + }, + }); + + MappedOperatorTaskGroup loss_mapping{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}, + }, + }}; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; OptimizerAttrs optimizer_attrs = - OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, - /*momentum=*/0.9, - /*nesterov=*/false, - /*weight_decay=*/0.001}}; + OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; std::unordered_map input_tensors; diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index cbfe3ab264..26c98e915c 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -48,7 +48,7 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set const &); std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &, + get_open_parallel_tensor_uses(SubParallelComputationGraph const &, open_parallel_tensor_guid_t const &); SubParallelComputationGraphData diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 6ed2ef563e..a56555550f 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -109,7 +109,7 @@ SubParallelComputationGraph apply_substitution_from_output_result( input_parallel_tensor_guid_t output_graph_input = output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( output_expr_input); - std::unordered_set uses = get_parallel_tensor_uses( + std::unordered_set uses = get_open_parallel_tensor_uses( substitution_output_graph, open_parallel_tensor_guid_from_input(output_graph_input)); for (parallel_tensor_use_t const &use : uses) { diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 34b8ae1e96..990975bff9 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -131,7 +131,7 @@ std::unordered_set get_subgraph_incoming_edges( } std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, + get_open_parallel_tensor_uses(SubParallelComputationGraph const &spcg, open_parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = get_open_kwarg_dataflow_value_uses(spcg.raw_graph, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 0aea7d2324..b23edc0411 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -23,24 +23,24 @@ static bidict // get_incoming_edges returns map // replicate has exactly one input auto [input_slot_name, input_edge] = - get_only(get_incoming_edges(mpcg.pcg, replicate_layer)); + get_only(mpcg_get_incoming_edges(mpcg, replicate_layer)); parallel_layer_guid_t producer_layer = get_src_layer(input_edge); TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); - return get_tensor_bindings_for_slot_name(mpcg.mapped_tasks.at(producer_layer), - producer_slot); + return get_tensor_bindings_for_slot_name( + /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), + /*slot_name=*/producer_slot); } static std::unordered_map get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, parallel_tensor_guid_t const &tensor) { - parallel_layer_guid_t producer_layer = get_source_layer(mpcg.pcg, tensor); + parallel_layer_guid_t producer_layer = mpcg_get_source_layer(mpcg, tensor); std::unordered_map result; // get_outgoing_edges returns unordered_set - for (ParallelComputationGraphEdge const &edge : - get_outgoing_edges(mpcg.pcg, producer_layer)) { + for (ParallelComputationGraphEdge const &edge : mpcg_get_outgoing_edges(mpcg, producer_layer)) { if (get_parallel_tensor(edge) == tensor) { result.insert( std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); @@ -55,7 +55,7 @@ static bidict parallel_layer_guid_t const &replicate_layer) { auto [output_slot_name, output_tensor_guid] = - get_only(get_outgoing_tensors(mpcg.pcg, replicate_layer)); + get_only(mpcg_get_outgoing_tensors(mpcg, replicate_layer)); auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); ASSERT(!consumers.empty()); @@ -64,8 +64,7 @@ static bidict // (discard_copy, machine) pair since replicas are always on different machines bidict result; for (auto const &[consumer_layer, slot_name] : consumers) { - MappedOperatorTaskGroup consumer_mapping = - mpcg.mapped_tasks.at(consumer_layer); + MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); bidict binding = get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); for (auto const &[p, m] : binding) { @@ -80,14 +79,13 @@ static DynamicNodeInvocation ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { auto [input_slot_name, input_tensor_guid] = - get_only(get_incoming_tensors(mpcg.pcg, layer)); - auto incoming = get_incoming_tensors(mpcg.pcg, layer); - ASSERT(!incoming.empty(), - "replicate layer has no incoming tensors — " - "check PCG edge construction in test"); + get_only(mpcg_get_incoming_tensors(mpcg, layer).l_to_r()); + + auto incoming = mpcg_get_incoming_tensors(mpcg, layer); + ASSERT(!incoming.empty(), "Replicate layer has no incoming tensors."); ParallelTensorAttrs input_attrs = - get_parallel_tensor_attrs(mpcg.pcg, input_tensor_guid); + mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); bidict input_mapping = get_input_mapping_for_replicate(mpcg, layer); @@ -101,9 +99,9 @@ static DynamicNodeInvocation }; auto [output_slot_name, output_tensor_guid] = - get_only(get_outgoing_tensors(mpcg.pcg, layer)); + get_only(mpcg_get_outgoing_tensors(mpcg, layer)); ParallelTensorAttrs output_attrs = - get_parallel_tensor_attrs(mpcg.pcg, output_tensor_guid); + mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); DynamicValueAttrs output_value{ /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h new file mode 100644 index 0000000000..b5557e9e49 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_value_uses( + KwargDataflowGraphView const &g, + KwargDataflowOutput const &v) { + + KwargDataflowEdgeQuery query = + KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_single_value(v.node), + /*src_slots=*/query_set::match_single_value(v.slot_name), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + std::unordered_set> edges = + g.query_edges(query); + + return transform( + edges, [&](KwargDataflowEdge const &e) { + return e.dst; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index d2f727661c..c73f696172 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -19,6 +19,7 @@ #include #include #include +#include "utils/containers/require_same.h" namespace FlexFlow { @@ -106,6 +107,10 @@ struct ManyToOne { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map m_l_to_r; std::unordered_map> m_r_to_l; diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 30d84d34c3..7b725fdec1 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -23,6 +23,7 @@ #include #include #include +#include "utils/containers/require_same.h" namespace FlexFlow { @@ -114,6 +115,10 @@ struct OneToMany { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map> m_l_to_r; std::unordered_map m_r_to_l; diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc new file mode 100644 index 0000000000..2e42863e53 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc @@ -0,0 +1,14 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template + std::unordered_set> + get_kwarg_dataflow_value_uses( + KwargDataflowGraphView const &, + KwargDataflowOutput const &); + +} // namespace FlexFlow From ac4fffcb307fe1116fde88b3c7aa85599c224a4e Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 15 May 2026 21:39:42 -0700 Subject: [PATCH 06/19] Clean up pass expansion code --- .../op-attrs/pcg_operator_attrs.dtg.toml | 22 +- .../test/src/op-attrs/ops/element_unary.cc | 13 + .../mapped_parallel_computation_graph.h | 7 + .../parallel_tensor_use_t.h | 14 + .../mapped_parallel_computation_graph.cc | 27 +- .../parallel_tensor_use_t.cc | 13 + .../src/realm-execution/pcg_instance.cc | 1 - .../src/realm-execution/test_op_replicate.cc | 587 ++++++------------ .../output_expr_to_result_sub_pcg_mapping.cc | 4 +- .../src/substitutions/pcg_pattern_match.cc | 4 +- .../dynamic_graph/training_operation_attrs.h | 13 + ...mic_open_dataflow_graph_from_mapped_pcg.cc | 220 +++---- .../task-spec/dynamic_graph/pass_expansion.cc | 85 +-- .../dynamic_graph/training_operation_attrs.cc | 21 + .../task-spec/dynamic_graph/pass_expansion.cc | 270 +++++--- .../binary_merge_disjoint_bidicts.h | 37 ++ .../algorithms/merge_disjoint_bidicts.h | 39 +- lib/utils/include/utils/bidict/bidict.h | 8 + .../utils/containers/transform_pairs.h | 46 ++ .../binary_merge_disjoint_bidicts.cc | 12 + .../algorithms/merge_disjoint_bidicts.cc | 10 + .../src/utils/containers/transform_pairs.cc | 17 + ...ts.cc => binary_merge_disjoint_bidicts.cc} | 12 +- 23 files changed, 792 insertions(+), 690 deletions(-) create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc create mode 100644 lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h create mode 100644 lib/utils/include/utils/containers/transform_pairs.h create mode 100644 lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc create mode 100644 lib/utils/src/utils/containers/transform_pairs.cc rename lib/utils/test/src/utils/bidict/algorithms/{merge_disjoint_bidicts.cc => binary_merge_disjoint_bidicts.cc} (72%) diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml index 88a65f75c5..f2dd7c9350 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml @@ -11,13 +11,13 @@ features = [ ] includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", "op-attrs/ops/element_binary_attrs.dtg.h", @@ -61,7 +61,7 @@ key = "cast" [[values]] type = "::FlexFlow::CombineAttrs" -key = "combine_distributed" +key = "parallel_combine" [[values]] type = "::FlexFlow::ConcatAttrs" @@ -125,15 +125,15 @@ key = "reduce" [[values]] type = "::FlexFlow::ReductionAttrs" -key = "reduce_distributed" +key = "parallel_reduce" [[values]] type = "::FlexFlow::RepartitionAttrs" -key = "partition_distributed" +key = "parallel_partition" [[values]] type = "::FlexFlow::ReplicateAttrs" -key = "replicate_distributed" +key = "parallel_replicate" [[values]] type = "::FlexFlow::ReverseAttrs" diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 43b4be06d8..8b2555610e 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -53,6 +53,19 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + SUBCASE("discard copy degree > 1") { + positive_int degree = 2_p; + + ParallelTensorShape par_input = make_input( + SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); + } + SUBCASE("sum degree > 1") { positive_int degree = 2_p; diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 984a524c21..6c24d4c1e1 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -18,6 +18,9 @@ ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &, parallel_tensor_guid_t const &); +PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, parallel_tensor_guid_t const &); @@ -40,6 +43,10 @@ bidict std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h new file mode 100644 index 0000000000..88f1512149 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H + +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &); +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 571b89b6dd..3b996ccdab 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -54,22 +54,28 @@ parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const return get_source_layer(pcg_from_mpcg(mpcg), t); } +PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return pcg_get_op_attrs(pcg_from_mpcg(mpcg), l); +} + ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) + parallel_tensor_guid_t const &t) { return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); } std::unordered_map mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) + parallel_layer_guid_t const &l) { return get_incoming_edges(pcg_from_mpcg(mpcg), l); } std::unordered_set mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) + parallel_layer_guid_t const &l) { return get_outgoing_edges(pcg_from_mpcg(mpcg), l); } @@ -84,11 +90,24 @@ ManyToOne bidict mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) + parallel_layer_guid_t const &l) { return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); } +std::unordered_set + mpcg_get_edges(MappedParallelComputationGraph const &mpcg) +{ + return get_edges(pcg_from_mpcg(mpcg)); +} + +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) +{ + return pcg_get_parallel_tensor_uses(pcg_from_mpcg(mpcg), t); +} + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc new file mode 100644 index 0000000000..e93341d312 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc @@ -0,0 +1,13 @@ +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" + +namespace FlexFlow { + +parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { + return parallel_layer_guid_t{u.raw_dataflow_input.node}; +} + +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &u) { + return u.raw_dataflow_input.slot_name; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 17a6a383e6..f2edac7f88 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -216,7 +216,6 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; - // issue_replicate_bwd lambda auto issue_replicate_bwd = [&]() { DynamicValueAttrs output_grad = get_only( diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index cae5ca1756..2523cae798 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -49,6 +49,190 @@ static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); } +MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_type) { + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs{replicate_degree}; + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + MachineSpaceCoordinate cpu0{0_n, 0_n, device_type}; + MachineSpaceCoordinate cpu1{0_n, 1_n, device_type}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_component=*/FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /*sum_component=*/0_n, + /*discard_copy_component=*/1_n, + /*shard_component=*/FFOrdered{0_n}}; + + MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + { + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + }); + + return mpcg; +} + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { std::vector fake_args = @@ -61,215 +245,12 @@ TEST_SUITE(FF_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - positive_int batch_size = 10_p; - positive_int data_dim = 16_p; - positive_int hidden_dim = 32_p; - positive_int output_dim = 1_p; - - // 10,2 - TensorShape output_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - // 10,2 - TensorShape label_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - GenericTensorAccessorW label_tensor = - allocator.allocate_tensor(label_tensor_shape); - - // construct computation graph - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - // input tensor - // 10, 16 - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; - - // parallel layer -> input tensor - ParallelLayerAddedResult inputs_layer = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input = - require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> input tensor 2 - ParallelLayerAddedResult inputs_layer_2 = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input_2 = - require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); - - // binary ADD attribute - ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - false, - false, - }; - - // parallel layer -> perform add - ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(add_attrs), - { - { - TensorSlotName::LHS_INPUT, - t_input, - }, - { - TensorSlotName::RHS_INPUT, - t_input_2, - }, - }, - {/* weight */}); - - parallel_tensor_guid_t t_add_1 = - require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform replicate - const positive_int replicate_degree = 2_p; - ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); - ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(repl_attrs), - { - { - TensorSlotName::INPUT, - t_add_1, - }, - }, - /*weight=*/{}); - // output of replicate layer - parallel_tensor_guid_t t_repl_1 = - require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform RelU - ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(make_relu_attrs()), - /*inputs=*/ - { - { - TensorSlotName::INPUT, - t_repl_1, - }, - }, - /*weights=*/{}); - // output of relu layer - parallel_tensor_guid_t t_relu_1 = - require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); - - // machine - MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; - MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; - - ParallelTensorSpaceCoordinate tensor_coord0{ - /* sum_component */ 0_n, - /* discard_copy_component */ 0_n, - /*shard_component*/ FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{ - /* sum_component */ 0_n, - /* discard_copy_component */ 1_n, - /*shard_component*/ FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( - /*pcg=*/pcg, - /*mapped_op_task_groups=*/{ - { - inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }, - }, - }, - { - inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }, - }, - }, - { - add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }, - }, - }, - { - repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }, - }, - }, - { - relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }, - }, - }, - }); + MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::CPU); - MappedOperatorTaskGroup loss_mapping{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}, - }, - }, - }; - // instantiate computation graph - LossAttrs loss_attrs = LossAttrs{ - NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + std::unordered_map + input_tensors; + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ SGDOptimizerAttrs{ @@ -280,13 +261,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; - std::unordered_map - input_tensors; - DistributedFfHandle device_handle = create_distributed_ff_handle( ctx, /*workSpaceSize=*/1024 * 1024, /*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( /*ctx=*/ctx, /*mpcg=*/mpcg, @@ -324,194 +303,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - positive_int batch_size = 10_p; - positive_int data_dim = 16_p; - positive_int hidden_dim = 32_p; - positive_int output_dim = 1_p; - - // 10,2 - TensorShape output_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - // 10,2 - TensorShape label_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - GenericTensorAccessorW label_tensor = - allocator.allocate_tensor(label_tensor_shape); - - // construct computation graph - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - // input tensor - // 10, 16 - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; - - // parallel layer -> input tensor - ParallelLayerAddedResult inputs_layer = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input = - require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> input tensor 2 - ParallelLayerAddedResult inputs_layer_2 = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input_2 = - require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); - - // binary ADD attribute - ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - false, - false, - }; - - // parallel layer -> perform add - ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(add_attrs), - { - { - TensorSlotName::LHS_INPUT, - t_input, - }, - { - TensorSlotName::RHS_INPUT, - t_input_2, - }, - }, - {/* weight */}); - - parallel_tensor_guid_t t_add_1 = - require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform replicate - const positive_int replicate_degree = 2_p; - ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); - ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(repl_attrs), - { - { - TensorSlotName::INPUT, - t_add_1, - }, - }, - /*weight=*/{}); - // output of replicate layer - parallel_tensor_guid_t t_repl_1 = - require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform RelU - ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(make_relu_attrs()), - /*inputs=*/ - { - { - TensorSlotName::INPUT, - t_repl_1, - }, - }, - /*weights=*/{}); - // output of relu layer - parallel_tensor_guid_t t_relu_1 = - require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); - - // machine - MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; - MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; - ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( - /*pcg=*/pcg, - /*mapped_op_task_groups=*/{ - { - inputs_layer.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}, - }, - }}, - }, - { - inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}, - }}, - }, - }, - { - add_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }}, - }, - { - repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }}, - }, - { - relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }}, - }, - }); - - MappedOperatorTaskGroup loss_mapping{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}, - }, - }}; + MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::GPU); - // instantiate computation graph - LossAttrs loss_attrs = LossAttrs{ - NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; OptimizerAttrs optimizer_attrs = OptimizerAttrs{ SGDOptimizerAttrs{ diff --git a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc index 2ad5b54a17..4374a951f8 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_pairs.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/containers/values.h" #include "utils/containers/zip_values_strict.h" @@ -26,7 +26,7 @@ bidict mapping_for_layer = bidict_from_pairs(values( zip_values_strict(layer_outputs, output_graph_expr_outputs))); - result = merge_disjoint_bidicts(result, mapping_for_layer); + result = binary_merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index 498fd6c1bf..dbd968d476 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -5,7 +5,7 @@ #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/bidict/algorithms/exhaustive_relational_join.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" @@ -34,7 +34,7 @@ bidict exhaustive_relational_join(pattern_node_outputs.reversed(), matched_layer_output_tensors); - result = merge_disjoint_bidicts(result, mapping); + result = binary_merge_disjoint_bidicts(result, mapping); } return result; diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h new file mode 100644 index 0000000000..bb8ca4f840 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H + +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" +#include "op-attrs/operator_type.dtg.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index b23edc0411..664c615a90 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -13,15 +13,21 @@ #include #include #include +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" +#include "utils/containers/require_only_key.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/containers/map_keys_and_values.h" +#include "utils/containers/transform_pairs.h" namespace FlexFlow { + static bidict get_input_mapping_for_replicate( MappedParallelComputationGraph const &mpcg, parallel_layer_guid_t const &replicate_layer) { - // get_incoming_edges returns map - // replicate has exactly one input + ASSERT(mpcg_get_pcg_op_attrs(mpcg, replicate_layer).is_parallel_replicate()); + auto [input_slot_name, input_edge] = get_only(mpcg_get_incoming_edges(mpcg, replicate_layer)); @@ -33,44 +39,32 @@ static bidict /*slot_name=*/producer_slot); } -static std::unordered_map - get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &tensor) { - parallel_layer_guid_t producer_layer = mpcg_get_source_layer(mpcg, tensor); - - std::unordered_map result; - // get_outgoing_edges returns unordered_set - for (ParallelComputationGraphEdge const &edge : mpcg_get_outgoing_edges(mpcg, producer_layer)) { - if (get_parallel_tensor(edge) == tensor) { - result.insert( - std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); - } - } - return result; -} - static bidict build_replicated_output_mapping( MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &replicate_layer) { + parallel_tensor_guid_t const &output_tensor_guid) { - auto [output_slot_name, output_tensor_guid] = - get_only(mpcg_get_outgoing_tensors(mpcg, replicate_layer)); - - auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); + std::unordered_set consumers = mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); ASSERT(!consumers.empty()); // union all consumer bindings — each consumer shard maps to a distinct // (discard_copy, machine) pair since replicas are always on different machines - bidict result; - for (auto const &[consumer_layer, slot_name] : consumers) { - MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); - bidict binding = - get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); - for (auto const &[p, m] : binding) { - result.equate(p, m); - } - } + bidict result = + merge_disjoint_bidicts( + transform(consumers, + [&](parallel_tensor_use_t const &use) + -> bidict + { + parallel_layer_guid_t consumer_layer = parallel_tensor_use_get_layer(use); + TensorSlotName slot_name = parallel_tensor_use_get_slot(use); + + MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); + bidict binding = + get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); + + return binding; + })); + return result; } @@ -78,14 +72,19 @@ static DynamicNodeInvocation build_replicate_invocation(parallel_layer_guid_t const &layer, ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { - auto [input_slot_name, input_tensor_guid] = - get_only(mpcg_get_incoming_tensors(mpcg, layer).l_to_r()); - - auto incoming = mpcg_get_incoming_tensors(mpcg, layer); - ASSERT(!incoming.empty(), "Replicate layer has no incoming tensors."); + ManyToOne incoming = mpcg_get_incoming_tensors(mpcg, layer); + TensorSlotName input_slot_name = TensorSlotName::INPUT; + parallel_tensor_guid_t input_tensor_guid = require_only_key(incoming.l_to_r(), input_slot_name); ParallelTensorAttrs input_attrs = mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); + + bidict outgoing = mpcg_get_outgoing_tensors(mpcg, layer); + TensorSlotName output_slot_name = TensorSlotName::OUTPUT; + parallel_tensor_guid_t output_tensor_guid = require_only_key(outgoing.l_to_r(), output_slot_name); + ParallelTensorAttrs output_attrs = + mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); + bidict input_mapping = get_input_mapping_for_replicate(mpcg, layer); @@ -93,24 +92,20 @@ static DynamicNodeInvocation /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, /*parallel_tensor_shape=*/input_attrs.shape, /*shard_coord=*/std::nullopt, - /*mapping=*/get_input_mapping_for_replicate(mpcg, layer), + /*mapping=*/input_mapping, /*accessor=*/std::nullopt, /*role=*/std::nullopt, }; - auto [output_slot_name, output_tensor_guid] = - get_only(mpcg_get_outgoing_tensors(mpcg, layer)); - ParallelTensorAttrs output_attrs = - mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); - DynamicValueAttrs output_value{ /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, /*parallel_tensor_shape=*/output_attrs.shape, /*shard_coord=*/std::nullopt, - /*mapping=*/build_replicated_output_mapping(mpcg, layer), + /*mapping=*/build_replicated_output_mapping(mpcg, output_tensor_guid), /*accessor=*/std::nullopt, /*role=*/std::nullopt, }; + DynamicNodeAttrs node_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, @@ -122,85 +117,92 @@ static DynamicNodeInvocation DynamicNodeInvocation invocation_node{ /*inputs=*/{ - {DynamicTensorSlot{input_slot_name, std::nullopt}, input_value}}, + { + DynamicTensorSlot{input_slot_name, std::nullopt}, + input_value, + }, + }, /*node_attrs=*/node_attrs, - /*outputs=*/ - {{DynamicTensorSlot{output_slot_name, std::nullopt}, output_value}}, + /*outputs=*/{ + { + DynamicTensorSlot{output_slot_name, std::nullopt}, + output_value, + }, + }, }; + return invocation_node; } DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { - DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); - for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { - if (attrs.op_attrs.has()) { + auto mk_invocation = [&](parallel_layer_guid_t layer, ParallelLayerAttrs const &attrs) + -> DynamicNodeInvocation + { + if (attrs.op_attrs.is_parallel_replicate()) { // build replicate invocation DynamicNodeInvocation repl_inv = build_replicate_invocation( - layer, attrs.op_attrs.get(), mpcg); - result.invocations.emplace(repl_inv); - continue; - } - - DynamicNodeAttrs result_attrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, - /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, - /*per_device_op_state=*/std::nullopt, + layer, attrs.op_attrs.require_parallel_replicate(), mpcg); + return repl_inv; + } else { + DynamicNodeAttrs result_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }; + }; + + auto mk_value_attrs = [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs + { + ParallelTensorAttrs attrs = + get_parallel_tensor_attrs(pcg, tensor); + + return DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + + std::unordered_map result_inputs = + map_keys_and_values(get_incoming_tensors(pcg, layer), + mk_slot, + mk_value_attrs); + + std::unordered_map result_outputs = + map_keys_and_values(get_outgoing_tensors(pcg, layer), + mk_slot, + mk_value_attrs); + + DynamicNodeInvocation invocation = DynamicNodeInvocation{ + /*inputs=*/result_inputs, + /*node_attrs=*/result_attrs, + /*outputs=*/result_outputs, + }; + + return invocation; }; + }; - std::unordered_map result_inputs = - transform(get_incoming_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - std::unordered_map result_outputs = - transform(get_outgoing_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - - result.invocations.emplace(result_inputs, result_attrs, result_outputs); - } - - return result; + return dynamic_open_dataflow_graph_from_invocation_set( + transform_pairs( + unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), + mk_invocation)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index faa1e186c3..25958b5cb7 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -5,6 +5,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" namespace FlexFlow { @@ -88,6 +89,8 @@ DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation const &invocation) { + TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { return std::pair{ pass_expand_slot(k, FwbTensorType::FORWARD), @@ -102,56 +105,37 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( }; }; - return DynamicNodeInvocation{ - /*inputs=*/ - merge_disjoint_maps(std::vector{ - transform(invocation.inputs, to_fwd), - transform(invocation.outputs, to_fwd), - transform(invocation.outputs, to_grad), - }), - /*node_attrs=*/ - pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), - /*outputs=*/ - transform(invocation.inputs, to_grad), - }; -} -static std::unordered_set - perform_pass_expansion_for_replicate( - DynamicNodeInvocation const &invocation) { - - auto const &[input_slot, input] = get_only(invocation.inputs); - auto const &[output_slot, output] = get_only(invocation.outputs); - - // forward: INPUT/FWD → OUTPUT/FWD (copy to replicas) - DynamicNodeInvocation fwd{ - /*inputs=*/{{pass_expand_slot(input_slot, FwbTensorType::FORWARD), - pass_expand_value(input, FwbTensorType::FORWARD)}}, - /*node_attrs=*/ - pass_expand_node(invocation.node_attrs, DynamicTaskType::FWD), - /*outputs=*/ - {{pass_expand_slot(output_slot, FwbTensorType::FORWARD), - pass_expand_value(output, FwbTensorType::FORWARD)}}, - }; - - // backward: OUTPUT/FWD + OUTPUT/GRAD → INPUT/GRAD (reduce gradients) - // The backward node needs the mapping from the output (replicated) - // so it knows which replicas to reduce from - DynamicNodeAttrs bwd_node_attrs = invocation.node_attrs; - bwd_node_attrs.task_type = DynamicTaskType::BWD; - - DynamicNodeInvocation bwd{ - /*inputs=*/{ - {pass_expand_slot(output_slot, FwbTensorType::FORWARD), - pass_expand_value(output, FwbTensorType::FORWARD)}, - {pass_expand_slot(output_slot, FwbTensorType::GRADIENT), - pass_expand_value(output, FwbTensorType::GRADIENT)}, - }, - /*node_attrs=*/bwd_node_attrs, - /*outputs=*/ - {{pass_expand_slot(input_slot, FwbTensorType::GRADIENT), - pass_expand_value(input, FwbTensorType::GRADIENT)}}, + if (training_op_attrs_has_op_type(op_attrs, OperatorType::REPLICATE)) { + auto [input_slot, input] = get_only(invocation.inputs); + auto [output_slot, output] = get_only(invocation.outputs); + + DynamicNodeInvocation bwd{ + /*inputs=*/{ + to_fwd(output_slot, output), + to_grad(output_slot, output), + }, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/{ + to_grad(input_slot, input), + }, + }; + + return bwd; + } else { + return DynamicNodeInvocation{ + /*inputs=*/ + merge_disjoint_maps(std::vector{ + transform(invocation.inputs, to_fwd), + transform(invocation.outputs, to_fwd), + transform(invocation.outputs, to_grad), + }), + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + transform(invocation.inputs, to_grad), + }; }; - return {fwd, bwd}; } DynamicOpenDataflowGraph @@ -161,9 +145,6 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { - if (is_replicate_attrs(invocation.node_attrs)) { - return perform_pass_expansion_for_replicate(invocation); - } if (invocation.inputs.empty()) { return std::unordered_set{ perform_fwd_pass_expansion_for_invocation(invocation), diff --git a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc new file mode 100644 index 0000000000..d1452242ca --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc @@ -0,0 +1,21 @@ +#include "task-spec/dynamic_graph/training_operation_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, OperatorType op_type) { + return op_attrs.visit(overload { + [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { + return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; + }, + [](LossAttrs const &) -> bool { + return false; + }, + [](CopyAttrs const &) -> bool { + return false; + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index fb087f5295..ed22a8cbde 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -2,6 +2,7 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include +#include "op-attrs/ops/element_unary.h" using namespace ::FlexFlow; @@ -36,6 +37,19 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; + TrainingOperationAttrs op_attrs = + TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, + }, + }; + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); @@ -46,14 +60,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -79,14 +92,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -130,88 +142,163 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; - DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { - DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); - DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); - DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, - {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, - {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, - }, - }; - }(); - - DynamicNodeInvocation result = - perform_bwd_pass_expansion_for_invocation(invocation); - - DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { - DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; - DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; - - DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); - DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); - DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); - DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); - DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); - DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, - {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, - {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*pass_type=*/DynamicTaskType::BWD, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); + + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); + DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); + DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); + DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); + DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); + DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); + + SUBCASE("normal operator") { + TrainingOperationAttrs op_attrs = + TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, - {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, - {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, - {mk_slot(TensorSlotName::SCALE, grad_role), v1_grad}, + }; + + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, + {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, + {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } + + SUBCASE("replicate operator optimization") { + TrainingOperationAttrs op_attrs = + TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, - }; - }(); - - ASSERT(result == correct); + }; + + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v2}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v2_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } } TEST_CASE("perform_pass_expansion(DynamicOpenDataflowGraph)") { auto mk_node_attrs = [](size_t layer_id, + TrainingOperationAttrs const &op_attrs, std::optional const &pass_type) -> DynamicNodeAttrs { return DynamicNodeAttrs{ /*pass_type=*/pass_type, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{layer_id}}}, /*per_device_op_state=*/std::nullopt, @@ -236,9 +323,32 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; + TrainingOperationAttrs input_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + InputAttrs{ + TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 8_p, + }, + }, + DataType::FLOAT, + }, + }, + }, + }; + + TrainingOperationAttrs relu_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), + }, + }; + + DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1 = mk_node_attrs(10, std::nullopt); - DynamicNodeAttrs n2 = mk_node_attrs(11, std::nullopt); + DynamicNodeAttrs n1 = mk_node_attrs(10, input_op_attrs, std::nullopt); + DynamicNodeAttrs n2 = mk_node_attrs(11, relu_op_attrs, std::nullopt); DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); @@ -286,10 +396,10 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicOpenDataflowGraph result = perform_pass_expansion(input); DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1_fwd = mk_node_attrs(10, DynamicTaskType::FWD); - DynamicNodeAttrs n2_fwd = mk_node_attrs(11, DynamicTaskType::FWD); - DynamicNodeAttrs n1_bwd = mk_node_attrs(10, DynamicTaskType::BWD); - DynamicNodeAttrs n2_bwd = mk_node_attrs(11, DynamicTaskType::BWD); + DynamicNodeAttrs n1_fwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); DynamicValueAttrs v1_activation = mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); diff --git a/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h new file mode 100644 index 0000000000..5b0bb45910 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict binary_merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); + } + + bidict result; + for (auto const &kv : lhs) { + result.equate_strict(kv.first, kv.second); + } + for (auto const &kv : rhs) { + result.equate_strict(kv.first, kv.second); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index 97e7334c26..f2104fd113 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,35 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H -#include "utils/bidict/algorithms/left_entries.h" -#include "utils/bidict/algorithms/right_entries.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/are_disjoint.h" -#include "utils/exception.h" +#include "utils/containers/foldl.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" namespace FlexFlow { -template -bidict merge_disjoint_bidicts(bidict const &lhs, - bidict const &rhs) { - if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); - } - if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); - } - - bidict result; - for (auto const &kv : lhs) { - result.equate(kv.first, kv.second); - } - for (auto const &kv : rhs) { - result.equate(kv.first, kv.second); - } - - return result; +template +bidict merge_disjoint_bidicts(C const &c) { + bidict empty = {}; + return foldl(c, + /*init=*/empty, + [](bidict const &lhs, + bidict const &rhs) { + return binary_merge_disjoint_bidicts(lhs, rhs); + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 5dbd1c603d..2d8c5d23a8 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -213,6 +213,14 @@ struct bidict { return this->fwd_map; } + std::unordered_map const &l_to_r() const { + return this->fwd_map; + } + + std::unordered_map const &r_to_l() const { + return this->bwd_map; + } + bidict(std::unordered_map const &fwd_map, std::unordered_map const &bwd_map) : fwd_map(fwd_map), bwd_map(bwd_map) {} diff --git a/lib/utils/include/utils/containers/transform_pairs.h b/lib/utils/include/utils/containers/transform_pairs.h new file mode 100644 index 0000000000..c01b50554f --- /dev/null +++ b/lib/utils/include/utils/containers/transform_pairs.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H + +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template > +std::vector transform_pairs(std::vector> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::unordered_set transform_pairs(std::unordered_set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::set transform_pairs(std::set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..8650de44f6 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -0,0 +1,12 @@ +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template + bidict binary_merge_disjoint_bidicts(bidict const &, bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc index 754b8d2e90..2c27821d3b 100644 --- a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -1 +1,11 @@ #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template bidict merge_disjoint_bidicts(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform_pairs.cc b/lib/utils/src/utils/containers/transform_pairs.cc new file mode 100644 index 0000000000..241f1ad425 --- /dev/null +++ b/lib/utils/src/utils/containers/transform_pairs.cc @@ -0,0 +1,17 @@ +#include "utils/containers/transform_pairs.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; +using Out = value_type<2>; +using F = std::function; + +template + std::vector transform_pairs(std::vector> const &, F &&); + +template + std::unordered_set transform_pairs(std::unordered_set> const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc similarity index 72% rename from lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc rename to lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc index 0a1babd9f9..8a3371b8d8 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("merge_disjoint_bidicts") { + TEST_CASE("binary_merge_disjoint_bidicts") { SUBCASE("disjoint keys and values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "three"}, {4, "four"}}; - bidict result = merge_disjoint_bidicts(bd1, bd2); + bidict result = binary_merge_disjoint_bidicts(bd1, bd2); bidict correct = { {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; @@ -22,21 +22,21 @@ TEST_SUITE(FF_TEST_SUITE) { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "three"}, {3, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping key, same associated value") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "two"}, {3, "three"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "two"}, {4, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } } } From fdf4fe5e74d4ede4cc21da923ee0aaedf5771351 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 15 May 2026 21:42:09 -0700 Subject: [PATCH 07/19] Remove unnecessary is_replicate_attrs function --- lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 25958b5cb7..f4960fe67a 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -31,11 +31,6 @@ bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); } -static bool is_replicate_attrs(DynamicNodeAttrs const &n) { - return n.op_attrs.has_value() && n.op_attrs.value().has() && - n.op_attrs.value().get().has(); -} - DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, FwbTensorType tensor_type) { ASSERT(!slot_is_pass_expanded(s)); From dce15e23118a128c65d7c35f7e57da91011329f4 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:14:33 -0700 Subject: [PATCH 08/19] Format. --- .../test/src/op-attrs/ops/element_unary.cc | 5 +- .../mapped_parallel_computation_graph.h | 20 +-- .../parallel_computation_graph.h | 2 +- .../parallel_tensor_use_t.h | 5 +- .../mapped_parallel_computation_graph.cc | 43 +++--- .../parallel_computation_graph.cc | 9 +- .../parallel_tensor_use_t.cc | 3 +- .../sub_parallel_computation_graph.h | 2 +- .../apply_substitution/apply_substitution.cc | 7 +- .../src/substitutions/pcg_pattern_match.cc | 2 +- .../sub_parallel_computation_graph.cc | 2 +- .../dynamic_graph/training_operation_attrs.h | 5 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 129 +++++++++--------- .../task-spec/dynamic_graph/pass_expansion.cc | 16 ++- .../dynamic_graph/training_operation_attrs.cc | 19 ++- .../task-spec/dynamic_graph/pass_expansion.cc | 95 ++++++------- .../algorithms/merge_disjoint_bidicts.h | 5 +- .../utils/containers/transform_pairs.h | 3 +- .../get_kwarg_dataflow_value_uses.h | 33 ++--- .../include/utils/many_to_one/many_to_one.h | 2 +- .../include/utils/one_to_many/one_to_many.h | 2 +- .../binary_merge_disjoint_bidicts.cc | 4 +- .../src/utils/containers/transform_pairs.cc | 8 +- .../get_kwarg_dataflow_value_uses.cc | 8 +- 24 files changed, 210 insertions(+), 219 deletions(-) diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 8b2555610e..09e49a123c 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -56,8 +56,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("discard copy degree > 1") { positive_int degree = 2_p; - ParallelTensorShape par_input = make_input( - SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); + ParallelTensorShape par_input = + make_input(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); tl::expected result = get_output_shape(attrs, par_input); @@ -74,6 +74,5 @@ TEST_SUITE(FF_TEST_SUITE) { make_input( SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } - } } diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 6c24d4c1e1..a2afdb7914 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -15,22 +15,24 @@ MappedOperatorTaskGroup ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); -parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &, - parallel_tensor_guid_t const &); +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &, parallel_layer_guid_t const &); -ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, - parallel_tensor_guid_t const &); +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); std::unordered_map - mpcg_get_incoming_edges(MappedParallelComputationGraph const &, - parallel_layer_guid_t const &); + mpcg_get_incoming_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); std::unordered_set - mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, - parallel_layer_guid_t const &); + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); ManyToOne mpcg_get_incoming_tensors(MappedParallelComputationGraph const &, @@ -38,7 +40,7 @@ ManyToOne bidict mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &, - parallel_layer_guid_t const &); + parallel_layer_guid_t const &); std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 1b2d5a0b67..9764e40627 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -10,8 +10,8 @@ #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include #include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h index 88f1512149..f5e5575632 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H -#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" namespace FlexFlow { -parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &); +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &); TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &); } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 3b996ccdab..fc1dff504b 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -2,13 +2,13 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/bidict/algorithms/transform_keys.h" #include "utils/containers/transform.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" -#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/many_to_one/many_to_one_from_map.h" namespace FlexFlow { @@ -48,63 +48,56 @@ ParallelComputationGraph }; } -parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) -{ +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { return get_source_layer(pcg_from_mpcg(mpcg), t); } -PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ +PCGOperatorAttrs + mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { return pcg_get_op_attrs(pcg_from_mpcg(mpcg), l); } -ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) -{ +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); } std::unordered_map - mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { return get_incoming_edges(pcg_from_mpcg(mpcg), l); } std::unordered_set - mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { return get_outgoing_edges(pcg_from_mpcg(mpcg), l); } ManyToOne mpcg_get_incoming_tensors(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + parallel_layer_guid_t const &l) { return many_to_one_from_map(get_incoming_tensors(pcg_from_mpcg(mpcg), l)); } - bidict mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + parallel_layer_guid_t const &l) { return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); } std::unordered_set - mpcg_get_edges(MappedParallelComputationGraph const &mpcg) -{ + mpcg_get_edges(MappedParallelComputationGraph const &mpcg) { return get_edges(pcg_from_mpcg(mpcg)); } std::unordered_set mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) -{ + parallel_tensor_guid_t const &t) { return pcg_get_parallel_tensor_uses(pcg_from_mpcg(mpcg), t); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 2c5197242d..5098cadafe 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -28,6 +28,7 @@ #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" @@ -36,7 +37,6 @@ #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" #include -#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" namespace FlexFlow { @@ -209,18 +209,15 @@ std::unordered_map std::unordered_set pcg_get_parallel_tensor_uses(ParallelComputationGraph const &pcg, - parallel_tensor_guid_t const &t) -{ + parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = - get_kwarg_dataflow_value_uses(pcg.raw_graph, - t.raw_graph_output); + get_kwarg_dataflow_value_uses(pcg.raw_graph, t.raw_graph_output); return transform(raw_uses, [](KwargDataflowInput const &i) { return parallel_tensor_use_t{i}; }); } - std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc index e93341d312..71a9cadf1c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { return parallel_layer_guid_t{u.raw_dataflow_input.node}; } diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 26c98e915c..2a3dc8bbb8 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -49,7 +49,7 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set get_open_parallel_tensor_uses(SubParallelComputationGraph const &, - open_parallel_tensor_guid_t const &); + open_parallel_tensor_guid_t const &); SubParallelComputationGraphData get_sub_pcg_data(SubParallelComputationGraph const &); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index a56555550f..f2686f7cf7 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -109,9 +109,10 @@ SubParallelComputationGraph apply_substitution_from_output_result( input_parallel_tensor_guid_t output_graph_input = output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( output_expr_input); - std::unordered_set uses = get_open_parallel_tensor_uses( - substitution_output_graph, - open_parallel_tensor_guid_from_input(output_graph_input)); + std::unordered_set uses = + get_open_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); for (parallel_tensor_use_t const &use : uses) { SubParallelComputationGraphEdge new_edge = subpcg_edge_from_tensor_and_use(base_graph_tensor, use); diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index dbd968d476..85a0493e33 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -4,8 +4,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/bidict_from_map.h" -#include "utils/bidict/algorithms/exhaustive_relational_join.h" #include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/exhaustive_relational_join.h" #include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 990975bff9..c0c05ad5b1 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -132,7 +132,7 @@ std::unordered_set get_subgraph_incoming_edges( std::unordered_set get_open_parallel_tensor_uses(SubParallelComputationGraph const &spcg, - open_parallel_tensor_guid_t const &t) { + open_parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = get_open_kwarg_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h index bb8ca4f840..9caea8c341 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H -#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "op-attrs/operator_type.dtg.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" namespace FlexFlow { -bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, OperatorType); +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, + OperatorType); } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 664c615a90..7a149787b9 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -5,19 +5,19 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" +#include "utils/containers/map_keys_and_values.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/transform_pairs.h" #include #include #include -#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" -#include "utils/containers/require_only_key.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" -#include "utils/containers/map_keys_and_values.h" -#include "utils/containers/transform_pairs.h" namespace FlexFlow { @@ -35,8 +35,8 @@ static bidict TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); return get_tensor_bindings_for_slot_name( - /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), - /*slot_name=*/producer_slot); + /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), + /*slot_name=*/producer_slot); } static bidict @@ -44,26 +44,29 @@ static bidict MappedParallelComputationGraph const &mpcg, parallel_tensor_guid_t const &output_tensor_guid) { - std::unordered_set consumers = mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); + std::unordered_set consumers = + mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); ASSERT(!consumers.empty()); // union all consumer bindings — each consumer shard maps to a distinct // (discard_copy, machine) pair since replicas are always on different machines bidict result = - merge_disjoint_bidicts( - transform(consumers, - [&](parallel_tensor_use_t const &use) - -> bidict - { - parallel_layer_guid_t consumer_layer = parallel_tensor_use_get_layer(use); - TensorSlotName slot_name = parallel_tensor_use_get_slot(use); - - MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); - bidict binding = - get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); - - return binding; - })); + merge_disjoint_bidicts(transform( + consumers, + [&](parallel_tensor_use_t const &use) + -> bidict { + parallel_layer_guid_t consumer_layer = + parallel_tensor_use_get_layer(use); + TensorSlotName slot_name = parallel_tensor_use_get_slot(use); + + MappedOperatorTaskGroup consumer_mapping = + mpcg_get_mapping_for_layer(mpcg, consumer_layer); + bidict + binding = get_tensor_bindings_for_slot_name(consumer_mapping, + slot_name); + + return binding; + })); return result; } @@ -73,15 +76,19 @@ static DynamicNodeInvocation ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { - ManyToOne incoming = mpcg_get_incoming_tensors(mpcg, layer); + ManyToOne incoming = + mpcg_get_incoming_tensors(mpcg, layer); TensorSlotName input_slot_name = TensorSlotName::INPUT; - parallel_tensor_guid_t input_tensor_guid = require_only_key(incoming.l_to_r(), input_slot_name); + parallel_tensor_guid_t input_tensor_guid = + require_only_key(incoming.l_to_r(), input_slot_name); ParallelTensorAttrs input_attrs = mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); - bidict outgoing = mpcg_get_outgoing_tensors(mpcg, layer); + bidict outgoing = + mpcg_get_outgoing_tensors(mpcg, layer); TensorSlotName output_slot_name = TensorSlotName::OUTPUT; - parallel_tensor_guid_t output_tensor_guid = require_only_key(outgoing.l_to_r(), output_slot_name); + parallel_tensor_guid_t output_tensor_guid = + require_only_key(outgoing.l_to_r(), output_slot_name); ParallelTensorAttrs output_attrs = mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); @@ -117,17 +124,18 @@ static DynamicNodeInvocation DynamicNodeInvocation invocation_node{ /*inputs=*/{ - { - DynamicTensorSlot{input_slot_name, std::nullopt}, - input_value, - }, + { + DynamicTensorSlot{input_slot_name, std::nullopt}, + input_value, + }, }, /*node_attrs=*/node_attrs, - /*outputs=*/{ - { - DynamicTensorSlot{output_slot_name, std::nullopt}, - output_value, - }, + /*outputs=*/ + { + { + DynamicTensorSlot{output_slot_name, std::nullopt}, + output_value, + }, }, }; @@ -139,9 +147,9 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); - auto mk_invocation = [&](parallel_layer_guid_t layer, ParallelLayerAttrs const &attrs) - -> DynamicNodeInvocation - { + auto mk_invocation = + [&](parallel_layer_guid_t layer, + ParallelLayerAttrs const &attrs) -> DynamicNodeInvocation { if (attrs.op_attrs.is_parallel_replicate()) { // build replicate invocation DynamicNodeInvocation repl_inv = build_replicate_invocation( @@ -159,50 +167,45 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { return DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, }; }; - auto mk_value_attrs = [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs - { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); + auto mk_value_attrs = + [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs { + ParallelTensorAttrs attrs = get_parallel_tensor_attrs(pcg, tensor); return DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, }; }; std::unordered_map result_inputs = - map_keys_and_values(get_incoming_tensors(pcg, layer), - mk_slot, - mk_value_attrs); + map_keys_and_values( + get_incoming_tensors(pcg, layer), mk_slot, mk_value_attrs); std::unordered_map result_outputs = - map_keys_and_values(get_outgoing_tensors(pcg, layer), - mk_slot, - mk_value_attrs); + map_keys_and_values( + get_outgoing_tensors(pcg, layer), mk_slot, mk_value_attrs); DynamicNodeInvocation invocation = DynamicNodeInvocation{ - /*inputs=*/result_inputs, - /*node_attrs=*/result_attrs, - /*outputs=*/result_outputs, + /*inputs=*/result_inputs, + /*node_attrs=*/result_attrs, + /*outputs=*/result_outputs, }; return invocation; }; }; - return dynamic_open_dataflow_graph_from_invocation_set( - transform_pairs( - unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), - mk_invocation)); + return dynamic_open_dataflow_graph_from_invocation_set(transform_pairs( + unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), mk_invocation)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index f4960fe67a..64fe2df0be 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,11 +1,11 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" #include "utils/containers/are_all_same.h" #include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" -#include "task-spec/dynamic_graph/training_operation_attrs.h" namespace FlexFlow { @@ -84,7 +84,8 @@ DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation const &invocation) { - TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); + TrainingOperationAttrs op_attrs = + assert_unwrap(invocation.node_attrs.op_attrs); auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { return std::pair{ @@ -106,15 +107,16 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation bwd{ /*inputs=*/{ - to_fwd(output_slot, output), - to_grad(output_slot, output), + to_fwd(output_slot, output), + to_grad(output_slot, output), }, /*node_attrs=*/ pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), - /*outputs=*/{ - to_grad(input_slot, input), + /*outputs=*/ + { + to_grad(input_slot, input), }, - }; + }; return bwd; } else { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc index d1452242ca..a9be225ff5 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc @@ -4,17 +4,14 @@ namespace FlexFlow { -bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, OperatorType op_type) { - return op_attrs.visit(overload { - [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { - return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; - }, - [](LossAttrs const &) -> bool { - return false; - }, - [](CopyAttrs const &) -> bool { - return false; - }, +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, + OperatorType op_type) { + return op_attrs.visit(overload{ + [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { + return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; + }, + [](LossAttrs const &) -> bool { return false; }, + [](CopyAttrs const &) -> bool { return false; }, }); } diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index ed22a8cbde..bf88d5ec38 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,8 +1,8 @@ #include "task-spec/dynamic_graph/pass_expansion.h" +#include "op-attrs/ops/element_unary.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include -#include "op-attrs/ops/element_unary.h" using namespace ::FlexFlow; @@ -37,18 +37,17 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; - TrainingOperationAttrs op_attrs = - TrainingOperationAttrs{ + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ PCGOperatorAttrs{ - LinearAttrs{ - /*out_channels=*/8_p, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*activation=*/std::nullopt, - /*regularizer=*/std::nullopt, - }, + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, - }; + }; DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); @@ -157,18 +156,17 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); SUBCASE("normal operator") { - TrainingOperationAttrs op_attrs = - TrainingOperationAttrs{ + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ PCGOperatorAttrs{ - LinearAttrs{ - /*out_channels=*/8_p, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*activation=*/std::nullopt, - /*regularizer=*/std::nullopt, - }, + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, - }; + }; DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { return DynamicNodeInvocation{ @@ -227,14 +225,13 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("replicate operator optimization") { - TrainingOperationAttrs op_attrs = - TrainingOperationAttrs{ + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ PCGOperatorAttrs{ - ReplicateAttrs{ - /*replicate_degree=*/2_p, - }, + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, - }; + }; DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { return DynamicNodeInvocation{ @@ -262,7 +259,8 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; - DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + DynamicTensorRole grad_role = + DynamicTensorRole{FwbTensorType::GRADIENT}; return DynamicNodeInvocation{ /*inputs=*/{ @@ -324,28 +322,27 @@ TEST_SUITE(FF_TEST_SUITE) { }; TrainingOperationAttrs input_op_attrs = TrainingOperationAttrs{ - PCGOperatorAttrs{ - InputAttrs{ - TensorShape{ - TensorDims{ - FFOrdered{ - 4_p, - 8_p, - }, + PCGOperatorAttrs{ + InputAttrs{ + TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 8_p, + }, + }, + DataType::FLOAT, + }, }, - DataType::FLOAT, - }, }, - }, }; TrainingOperationAttrs relu_op_attrs = TrainingOperationAttrs{ - PCGOperatorAttrs{ - make_relu_attrs(), - }, + PCGOperatorAttrs{ + make_relu_attrs(), + }, }; - DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { DynamicNodeAttrs n1 = mk_node_attrs(10, input_op_attrs, std::nullopt); DynamicNodeAttrs n2 = mk_node_attrs(11, relu_op_attrs, std::nullopt); @@ -396,10 +393,14 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicOpenDataflowGraph result = perform_pass_expansion(input); DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1_fwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); - DynamicNodeAttrs n2_fwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); - DynamicNodeAttrs n1_bwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); - DynamicNodeAttrs n2_bwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n1_fwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); DynamicValueAttrs v1_activation = mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); diff --git a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index f2104fd113..0c944bb9bd 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H -#include "utils/containers/foldl.h" #include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/containers/foldl.h" namespace FlexFlow { @@ -13,8 +13,7 @@ bidict merge_disjoint_bidicts(C const &c) { bidict empty = {}; return foldl(c, /*init=*/empty, - [](bidict const &lhs, - bidict const &rhs) { + [](bidict const &lhs, bidict const &rhs) { return binary_merge_disjoint_bidicts(lhs, rhs); }); } diff --git a/lib/utils/include/utils/containers/transform_pairs.h b/lib/utils/include/utils/containers/transform_pairs.h index c01b50554f..3e421ea445 100644 --- a/lib/utils/include/utils/containers/transform_pairs.h +++ b/lib/utils/include/utils/containers/transform_pairs.h @@ -21,7 +21,8 @@ template > -std::unordered_set transform_pairs(std::unordered_set> const &c, F &&f) { +std::unordered_set + transform_pairs(std::unordered_set> const &c, F &&f) { auto ff = [&](std::pair const &p) -> Out { return f(p.first, p.second); }; diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h index b5557e9e49..52c225d157 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h @@ -7,25 +7,20 @@ namespace FlexFlow { template std::unordered_set> - get_kwarg_dataflow_value_uses( - KwargDataflowGraphView const &g, - KwargDataflowOutput const &v) { - - KwargDataflowEdgeQuery query = - KwargDataflowEdgeQuery{ - /*src_nodes=*/query_set::match_single_value(v.node), - /*src_slots=*/query_set::match_single_value(v.slot_name), - /*dst_nodes=*/query_set::matchall(), - /*dst_slots=*/query_set::matchall(), - }; - - std::unordered_set> edges = - g.query_edges(query); - - return transform( - edges, [&](KwargDataflowEdge const &e) { - return e.dst; - }); + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &g, + KwargDataflowOutput const &v) { + + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_single_value(v.node), + /*src_slots=*/query_set::match_single_value(v.slot_name), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + std::unordered_set> edges = g.query_edges(query); + + return transform(edges, + [&](KwargDataflowEdge const &e) { return e.dst; }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index c73f696172..2d078eb304 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" @@ -19,7 +20,6 @@ #include #include #include -#include "utils/containers/require_same.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 7b725fdec1..5492ff3f78 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -4,6 +4,7 @@ #include "utils/containers/generate_map.h" #include "utils/containers/items.h" #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" @@ -23,7 +24,6 @@ #include #include #include -#include "utils/containers/require_same.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc index 8650de44f6..13a1bcd968 100644 --- a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc +++ b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -6,7 +6,7 @@ namespace FlexFlow { using K = value_type<0>; using V = value_type<1>; -template - bidict binary_merge_disjoint_bidicts(bidict const &, bidict const &); +template bidict binary_merge_disjoint_bidicts(bidict const &, + bidict const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform_pairs.cc b/lib/utils/src/utils/containers/transform_pairs.cc index 241f1ad425..4afda936e4 100644 --- a/lib/utils/src/utils/containers/transform_pairs.cc +++ b/lib/utils/src/utils/containers/transform_pairs.cc @@ -8,10 +8,10 @@ using R = value_type<1>; using Out = value_type<2>; using F = std::function; -template - std::vector transform_pairs(std::vector> const &, F &&); +template std::vector transform_pairs(std::vector> const &, + F &&); -template - std::unordered_set transform_pairs(std::unordered_set> const &, F &&); +template std::unordered_set + transform_pairs(std::unordered_set> const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc index 2e42863e53..b1d2988223 100644 --- a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc @@ -5,10 +5,8 @@ namespace FlexFlow { using SlotName = ordered_value_type<0>; -template - std::unordered_set> - get_kwarg_dataflow_value_uses( - KwargDataflowGraphView const &, - KwargDataflowOutput const &); +template std::unordered_set> + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &, + KwargDataflowOutput const &); } // namespace FlexFlow From f2b075482f46dff9d5c66fd4a7616bc425f23421 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:15:38 -0700 Subject: [PATCH 09/19] Format Realm. --- .../src/realm-execution/pcg_instance.cc | 31 ++- .../src/realm-execution/test_op_replicate.cc | 185 +++++++++--------- 2 files changed, 107 insertions(+), 109 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index f2edac7f88..332669a9dc 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -217,14 +217,11 @@ static Realm::Event spawn_dynamic_node_invocation( }; auto issue_replicate_bwd = [&]() { - - DynamicValueAttrs output_grad = get_only( - values( - filter_keys( - invocation.inputs, - [](DynamicTensorSlot const &s) -> bool { - return s.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}; - }))); + DynamicValueAttrs output_grad = get_only(values( + filter_keys(invocation.inputs, [](DynamicTensorSlot const &s) -> bool { + return s.slot_tensor_role == + DynamicTensorRole{FwbTensorType::GRADIENT}; + }))); DynamicValueAttrs input_grad = get_only(values(invocation.outputs)); @@ -246,15 +243,15 @@ static Realm::Event spawn_dynamic_node_invocation( tensor_instance_backing.backing.at(replica_key).first; e = ctx.issue_copy( - /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), - /*src_inst=*/src_inst, - /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), - /*dst_inst=*/dst_inst, - /*requests=*/Realm::ProfilingRequestSet{}, - /*wait_on=*/e, - /*priority=*/0, - /*redop_id=*/redop_id, - /*exlusive=*/false); + /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), + /*src_inst=*/src_inst, + /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), + /*dst_inst=*/dst_inst, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/e, + /*priority=*/0, + /*redop_id=*/redop_id, + /*exlusive=*/false); } return e; }; diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index 2523cae798..46d29e2bef 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -13,6 +13,7 @@ #include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/device_type.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" @@ -27,7 +28,6 @@ #include "test/utils/doctest/check_kv.h" #include "utils/containers/require_only_key.h" #include -#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" namespace test { @@ -49,7 +49,8 @@ static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); } -MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_type) { +MappedParallelComputationGraph + make_test_mpcg_for_device_type(DeviceType device_type) { positive_int batch_size = 10_p; positive_int data_dim = 16_p; positive_int hidden_dim = 32_p; @@ -63,8 +64,8 @@ MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_ ParallelComputationGraph pcg = empty_parallel_computation_graph(); - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + TensorShape input_tensor_shape = + TensorShape{TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; ParallelLayerAddedResult inputs_layer = pcg_add_input_layer(pcg, input_tensor_shape); @@ -144,91 +145,92 @@ MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_ /*discard_copy_component=*/1_n, /*shard_component=*/FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( - /*pcg=*/pcg, - /*mapped_op_task_groups=*/{ - { - inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - { + MappedParallelComputationGraph mpcg = + mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, }, - }, - }, - }, - { - inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - { { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, }, - }, - }, - }, - { - add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, }, - }, - }, - }, - { - repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, }, { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, }, - }, - }, - }, - { - relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }, - }, - }, - }); + }); return mpcg; } @@ -245,21 +247,20 @@ TEST_SUITE(FF_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::CPU); - + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::CPU); std::unordered_map input_tensors; - OptimizerAttrs optimizer_attrs = - OptimizerAttrs{ - SGDOptimizerAttrs{ + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ /*lr=*/0.001, /*momentum=*/0.9, /*nesterov=*/false, /*weight_decay=*/0.001, - }, - }; + }, + }; DistributedFfHandle device_handle = create_distributed_ff_handle( ctx, @@ -303,17 +304,17 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::GPU); + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::GPU); - OptimizerAttrs optimizer_attrs = - OptimizerAttrs{ - SGDOptimizerAttrs{ + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ /*lr=*/0.001, /*momentum=*/0.9, /*nesterov=*/false, /*weight_decay=*/0.001, - }, - }; + }, + }; std::unordered_map input_tensors; From 9d03c9766beaa67b2cbc1eea87976d2c44ae6153 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:16:39 -0700 Subject: [PATCH 10/19] Refactor redop infrastructure and switch to Legion's redops. --- .../redops/realm_redop_registry.h | 16 + .../redops/redop_id_t.dtg.toml | 30 + .../realm-execution/redops/redop_id_t.h | 22 + .../realm-execution/tasks/realm_reduction.h | 154 ---- .../src/realm-execution/realm_manager.cc | 4 +- .../redops/realm_redop_registry.cc | 689 ++++++++++++++++++ .../src/realm-execution/redops/redop_id_t.cc | 32 + .../tasks/realm_task_registry.cc | 10 - 8 files changed, 792 insertions(+), 165 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h create mode 100644 lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/redops/redop_id_t.h delete mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_reduction.h create mode 100644 lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc create mode 100644 lib/realm-execution/src/realm-execution/redops/redop_id_t.cc diff --git a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h new file mode 100644 index 0000000000..a338a38bbf --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Registers all known reduction operators (redops). + */ +void Realm::Event register_all_redops(Realm::Runtime); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml new file mode 100644 index 0000000000..5183ff5e72 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "redop_id_t" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] +docstring = ''' +\brief An enum for identifying reduction operators (redops) for use in the Realm runtime. +''' + +[[values]] +name = "SUM_BOOL_REDOP_ID" + +[[values]] +name = "SUM_INT32_REDOP_ID" + +[[values]] +name = "SUM_INT64_REDOP_ID" + +[[values]] +name = "SUM_HALF_REDOP_ID" + +[[values]] +name = "SUM_FLOAT_REDOP_ID" + +[[values]] +name = "SUM_DOUBLE_REDOP_ID" diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h new file mode 100644 index 0000000000..b9ef91a05a --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Registers all known reduction operators (redops). + */ +Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType); + +/** + * \brief Convert a \ref FlexFlow::redop_id_t into a Realm reduction op ID. + */ +Realm::Processor::ReductionOpID + get_realm_reduction_op_id_for_redop_id(redop_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h deleted file mode 100644 index 512e344824..0000000000 --- a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h +++ /dev/null @@ -1,154 +0,0 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H -#include "op-attrs/datatype.dtg.h" -#include - -namespace FlexFlow { - -/** - * \brief Realm Sum Reduction for Float - * \see https://legion.stanford.edu/tutorial/realm/reductions.html - */ -struct SumReductionFloat { - using LHS = float; - using RHS = float; - - /** \brief Identity element for addition (0.0) */ - static constexpr RHS identity = 0.0f; - - /** - * \brief Apply reduction: lhs += rhs - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param lhs Left-hand side accumulator (modified in place) - * \param rhs Value to add - */ - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - // Atomic float add via CAS loop - union { - float f; - int i; - } old_val, new_val; - do { - old_val.f = lhs; - new_val.f = old_val.f + rhs; - } while ( - !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i)); - } - } - - /** - * \brief Fold two RHS values: rhs1 += rhs2 - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param rhs1 Accumulator (modified in place) - * \param rhs2 Value to fold in - */ - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - // Atomic float add via CAS loop - union { - float f; - int i; - } old_val, new_val; - do { - old_val.f = rhs1; - new_val.f = old_val.f + rhs2; - } while ( - !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i)); - } - } -}; - -/** - * \brief Realm Sum Reduction for Double - * \see https://legion.stanford.edu/tutorial/realm/reductions.html - */ -struct SumReductionDouble { - using LHS = double; - using RHS = double; - - /** \brief Identity element for addition (0.0) */ - static constexpr RHS identity = 0.0; - - /** - * \brief Apply reduction: lhs += rhs - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param lhs Left-hand side accumulator (modified in place) - * \param rhs Value to add - */ - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - // Atomic double add via CAS loop using long long reinterpretation - union { - double d; - long long i; - } old_val, new_val; - do { - old_val.d = lhs; - new_val.d = old_val.d + rhs; - } while (!__sync_bool_compare_and_swap( - (long long *)&lhs, old_val.i, new_val.i)); - } - } - - /** - * \brief Fold two RHS values: rhs1 += rhs2 - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param rhs1 Accumulator (modified in place) - * \param rhs2 Value to fold in - */ - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - // Atomic double add via CAS loop using long long reinterpretation - union { - double d; - long long i; - } old_val, new_val; - do { - old_val.d = rhs1; - new_val.d = old_val.d + rhs2; - } while (!__sync_bool_compare_and_swap( - (long long *)&rhs1, old_val.i, new_val.i)); - } - } -}; - -/** - * \brief Reduction op IDs for sum reductions - * \warning These IDs must not conflict with other registered reduction ops - */ -enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, ///< Sum reduction op ID for float - REDOP_SUM_DOUBLE = 2, ///< Sum reduction op ID for double -}; - -/** - * \brief Returns the Realm reduction op ID for a sum reduction over the given datatype - * \param dtype The datatype to look up - * \return The corresponding Realm::ReductionOpID - * \throws PANIC if no sum reduction is registered for the given datatype - */ -inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { - switch (dtype) { - case DataType::FLOAT: - return REDOP_SUM_FLOAT; - case DataType::DOUBLE: - return REDOP_SUM_DOUBLE; - default: - PANIC("no sum reduction registered for datatype {}", dtype); - } -} -} // namespace FlexFlow -#endif diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index e76be7054b..c7136d8a98 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,6 +1,7 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/redops/realm_redop_registry.h" namespace FlexFlow { @@ -9,8 +10,9 @@ RealmManager::RealmManager(int *argc, char ***argv) bool ok = this->get_runtime().init(argc, argv); ASSERT(ok); - // Register all tasks at initialization time so we don't need to later + // Register all tasks and redops at initialization time so we don't need to later register_all_tasks().wait(); + register_all_redops(this->get_runtime()); } RealmManager::~RealmManager() { diff --git a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc new file mode 100644 index 0000000000..d10b158463 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc @@ -0,0 +1,689 @@ +#include "realm-execution/redops/realm_redop_registry.h" + +namespace FlexFlow { + +// Reduction operators and related infrastructure borrowed from Legion. We +// maintain the Legion naming scheme to maximizing compatibility with the +// existing code, despite not otherwise relying or using Legion in any way. +// https://gitlab.com/StanfordLegion/legion/-/blob/5263aeff477fb94239c50d9306d58c4244e9fc38/runtime/legion/api/redop.inl#L31 +#if !defined(__cpp_lib_atomic_ref) || (__cpp_lib_atomic_ref < 201806L) +// We only need this crap if we're using a version of c++ < 20 +// Starting with c++20 we can do all this the right way with atomic_ref +namespace TypePunning { +// The tenth circle of hell is reserved for members of the C++ committee +// that decided to deviate from C's support for type punning unions. +// Add on to it the fact that it took them 9 fucking years to realize +// that they needed std::atomic_ref and it's plain to see they are all +// just a bunch of idiots that should never be allowed near a programming +// language standard ever again. They've clearly never written lock-free +// code in their lives. +template +class Pointer { +public: + Pointer(void *p) : pointer(convert(p)) {} + static inline T *convert(void *p) { + T *ptr = nullptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(p)); + return ptr; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline T operator[](size_t off) const { + return pointer[off]; + } + +private: + T volatile *const pointer; +}; +template +class AlignedPointer { +public: + AlignedPointer(void *p) : off(align(p)), pointer(convert(p, off)) {} + static inline T *convert(void *p, size_t off) { + uint8_t *p1 = nullptr; + static_assert(sizeof(p1) == sizeof(p)); + memcpy(&p1, &p, sizeof(p)); + p1 = p1 - off; + T *p2 = nullptr; + static_assert(sizeof(p1) == sizeof(p2)); + memcpy(&p2, &p1, sizeof(p1)); + return p2; + } + static inline size_t align(void *p) { + uintptr_t ptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(ptr)); + return ptr % ALIGNMENT; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline size_t offset(void) const { + return off; + } + +private: + size_t off; + T volatile *const pointer; +}; +template +class Alias { +public: + inline void load(Pointer const &pointer, size_t off = 0) { + T1 value = pointer[off]; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + template + inline void load(AlignedPointer const &pointer) { + T1 value = *pointer; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + inline T1 as_one(void) const { + T1 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline T2 as_two(void) const { + T2 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline Alias &operator=(T2 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + +private: + // Make this one private so it is can never be called + inline Alias &operator=(T1 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + static_assert(sizeof(T1) == sizeof(T2)); + uint8_t buffer[sizeof(T1)]; +}; +}; // namespace TypePunning +#endif + +// Define a prefix for annotating functions for CUDA compilation +#if defined(__CUDACC__) || defined(__HIPCC__) +#define __LEGION_CUDA_HD__ __host__ __device__ +#else +#define __LEGION_CUDA_HD__ +#endif + +template <> +class SumReduction { +public: + typedef bool LHS; + typedef bool RHS; + + static constexpr bool identity = false; + static constexpr int REDOP_ID = LEGION_REDOP_OR_BOOL; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int32_t LHS; + typedef int32_t RHS; + + static constexpr int32_t identity = 0; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT32; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int64_t LHS; + typedef int64_t RHS; + + static constexpr int64_t identity = 0; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT64; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction<__half> { +public: + typedef __half LHS; + typedef __half RHS; + + static inline const __half identity = __half(0, false /*raw*/); + static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT16; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef float LHS; + typedef float RHS; + + static constexpr float identity = 0.f; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT32; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef double LHS; + typedef double RHS; + + static constexpr double identity = 0.0; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT64; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs = lhs || rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&lhs); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 = rhs1 || rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&rhs1); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs2; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs2; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, + RHS rhs) { + lhs = lhs + rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) + atomicAdd(&lhs, rhs); +#else + // 16-bit atomics are not supported prior to volta + // 32-bit GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&lhs); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + RHS newval = lhs, oldval, other; + if (offset == 0) { + other = *((&lhs) + 1); + do { + oldval = newval; + newval = newval + rhs; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); + newval = __uint2lohalf(result); + other = __uint2hihalf(result); + } while (oldval != newval); + } else { + other = *((&lhs) - 1); + do { + oldval = newval; + newval = newval + rhs; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); + other = __uint2lohalf(result); + newval = __uint2hihalf(result); + } while (oldval != newval); + } +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias> oldval, newval; + TypePunning::AlignedPointer pointer((void *)&lhs); + unsigned const offset = pointer.offset() / sizeof(__half); + do { + oldval.load(pointer); + std::array next = oldval.as_two(); + next[offset] = __convert_float_to_halfint( + __convert_halfint_to_float(next[offset]) + float(rhs)); + newval = next; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, + RHS rhs2) { + rhs1 = rhs1 + rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) + atomicAdd(&rhs1, rhs2); +#else + // 16-bit atomics are not supported prior to volta + // 32-bit GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&rhs1); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + RHS newval = rhs1, oldval, other; + if (offset == 0) { + other = *((&rhs1) + 1); + do { + oldval = newval; + newval = newval + rhs2; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); + newval = __uint2lohalf(result); + other = __uint2hihalf(result); + } while (oldval != newval); + } else { + other = *((&rhs1) - 1); + do { + oldval = newval; + newval = newval + rhs2; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); + other = __uint2lohalf(result); + newval = __uint2hihalf(result); + } while (oldval != newval); + } +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias> oldval, newval; + TypePunning::AlignedPointer pointer((void *)&rhs1); + unsigned const offset = pointer.offset() / sizeof(__half); + do { + oldval.load(pointer); + std::array next = oldval.as_two(); + next[offset] = __convert_float_to_halfint( + __convert_halfint_to_float(next[offset]) + float(rhs2)); + newval = next; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +void Realm::Event register_all_redops(Realm::Runtime rt) { + // Registration is synchronous, so no need to capture events here + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_BOOL_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT32_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT64_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_HALF_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_FLOAT_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_DOUBLE_REDOP_ID)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc new file mode 100644 index 0000000000..702ddd5e97 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc @@ -0,0 +1,32 @@ +#include "realm-execution/redops/redop_id_t.h" + +namespace FlexFlow { + +Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { + + switch (dtype) { + case DataType::BOOL: + return redop_id_t::SUM_BOOL_REDOP_ID; + case DataType::INT32: + return redop_id_t::SUM_INT32_REDOP_ID; + case DataType::INT64: + return redop_id_t::SUM_INT64_REDOP_ID; + case DataType::HALF: + return redop_id_t::SUM_HALF_REDOP_ID; + case DataType::FLOAT: + return redop_id_t::SUM_FLOAT_REDOP_ID; + case DataType::DOUBLE: + return redop_id_t::SUM_DOUBLE_REDOP_ID; + default: + PANIC("No known sum reduction for data type {}", dtype); + } +} + +Realm::Processor::ReductionOpID + get_realm_reduction_op_id_for_redop_id(redop_id_t redop_id) { + return static_cast(redop_id); +} + +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index acafdf59fd..e7a8948f8d 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -5,7 +5,6 @@ #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_task.h" -#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/exception.h" @@ -31,18 +30,9 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::ProfilingRequestSet()); } -static void register_reductions() { - // register sum reduction ops - Realm::Runtime rt = Realm::Runtime::get_runtime(); - rt.register_reduction(REDOP_SUM_FLOAT); - rt.register_reduction(REDOP_SUM_DOUBLE); - // register_reduction is synchronous — no event returned -} - Realm::Event register_all_tasks() { std::vector pending_registrations; - register_reductions(); std::vector init_task_ids = { // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, From 0d589f2056022d1068deb2d6abf745e3dcc048cd Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:39:41 -0700 Subject: [PATCH 11/19] Fix build for reductions. --- .../redops/realm_redop_registry.h | 6 +- .../redops/redop_id_t.dtg.toml | 3 - .../realm-execution/redops/redop_id_t.h | 12 +- .../src/realm-execution/pcg_instance.cc | 7 +- .../src/realm-execution/realm_manager.cc | 2 +- .../redops/realm_redop_registry.cc | 165 +----------------- .../src/realm-execution/redops/redop_id_t.cc | 12 +- 7 files changed, 26 insertions(+), 181 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h index a338a38bbf..e7e51326e1 100644 --- a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h +++ b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H #include "realm-execution/realm.h" #include "realm-execution/redops/redop_id_t.dtg.h" @@ -9,7 +9,7 @@ namespace FlexFlow { /** * \brief Registers all known reduction operators (redops). */ -void Realm::Event register_all_redops(Realm::Runtime); +void register_all_redops(Realm::Runtime); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml index 5183ff5e72..44e1f32c59 100644 --- a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml @@ -20,9 +20,6 @@ name = "SUM_INT32_REDOP_ID" [[values]] name = "SUM_INT64_REDOP_ID" -[[values]] -name = "SUM_HALF_REDOP_ID" - [[values]] name = "SUM_FLOAT_REDOP_ID" diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h index b9ef91a05a..8565b20b17 100644 --- a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h @@ -1,21 +1,21 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H +#include "op-attrs/datatype.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/redops/redop_id_t.dtg.h" namespace FlexFlow { /** - * \brief Registers all known reduction operators (redops). + * \brief Return the sum reduction operator (redop) ID for a given data type. */ -Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType); +redop_id_t get_sum_redop_id_for_data_type(DataType); /** * \brief Convert a \ref FlexFlow::redop_id_t into a Realm reduction op ID. */ -Realm::Processor::ReductionOpID - get_realm_reduction_op_id_for_redop_id(redop_id_t); +Realm::ReductionOpID get_realm_reduction_op_id_for_redop_id(redop_id_t); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 332669a9dc..1ac3821142 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -5,8 +5,8 @@ #include "realm-execution/distributed_per_device_op_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/impl/op_task.h" -#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -228,8 +228,9 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance dst_inst = tensor_instance_backing.backing.at(input_grad).first; - Realm::ReductionOpID redop_id = get_sum_reduction_op_id( - assert_unwrap(output_grad.parallel_tensor_shape).data_type); + Realm::ReductionOpID redop_id = + get_realm_reduction_op_id_for_redop_id(get_sum_redop_id_for_data_type( + assert_unwrap(output_grad.parallel_tensor_shape).data_type)); // chain reductions sequentially to avoid write races on dst Realm::Event e = precondition; diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index c7136d8a98..5a8f9cbbbb 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,7 +1,7 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" -#include "realm-execution/tasks/realm_task_registry.h" #include "realm-execution/redops/realm_redop_registry.h" +#include "realm-execution/tasks/realm_task_registry.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc index d10b158463..ab3304836a 100644 --- a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc +++ b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc @@ -1,4 +1,5 @@ #include "realm-execution/redops/realm_redop_registry.h" +#include "realm-execution/redops/redop_id_t.h" namespace FlexFlow { @@ -120,6 +121,12 @@ class Alias { #define __LEGION_CUDA_HD__ #endif +template +class SumReduction { + // Empty definition + // Specializations provided for each type +}; + template <> class SumReduction { public: @@ -127,7 +134,6 @@ class SumReduction { typedef bool RHS; static constexpr bool identity = false; - static constexpr int REDOP_ID = LEGION_REDOP_OR_BOOL; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -142,7 +148,6 @@ class SumReduction { typedef int32_t RHS; static constexpr int32_t identity = 0; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT32; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -157,22 +162,6 @@ class SumReduction { typedef int64_t RHS; static constexpr int64_t identity = 0; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT64; - - template - __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); - template - __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); -}; - -template <> -class SumReduction<__half> { -public: - typedef __half LHS; - typedef __half RHS; - - static inline const __half identity = __half(0, false /*raw*/); - static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT16; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -187,7 +176,6 @@ class SumReduction { typedef float RHS; static constexpr float identity = 0.f; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT32; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -202,7 +190,6 @@ class SumReduction { typedef double RHS; static constexpr double identity = 0.0; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT64; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -382,140 +369,6 @@ __LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, #endif } -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, - RHS rhs) { - lhs = lhs + rhs; -} - -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, - RHS rhs) { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) -#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) - atomicAdd(&lhs, rhs); -#else - // 16-bit atomics are not supported prior to volta - // 32-bit GPU atomics need 4 byte alignment - const uintptr_t unaligned = reinterpret_cast(&lhs); - unsigned const offset = unaligned % sizeof(unsigned int); - const uintptr_t aligned = unaligned - offset; - unsigned int *ptr = reinterpret_cast(aligned); - RHS newval = lhs, oldval, other; - if (offset == 0) { - other = *((&lhs) + 1); - do { - oldval = newval; - newval = newval + rhs; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); - newval = __uint2lohalf(result); - other = __uint2hihalf(result); - } while (oldval != newval); - } else { - other = *((&lhs) - 1); - do { - oldval = newval; - newval = newval + rhs; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); - other = __uint2lohalf(result); - newval = __uint2hihalf(result); - } while (oldval != newval); - } -#endif -#else -#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) - std::atomic_ref atomic(lhs); - RHS oldval = atomic.load(); - RHS newval; - do { - newval = oldval + rhs; - } while (!atomic.compare_exchange_weak(oldval, newval)); -#else - // No atomic floating point operations so use compare and swap - TypePunning::Alias> oldval, newval; - TypePunning::AlignedPointer pointer((void *)&lhs); - unsigned const offset = pointer.offset() / sizeof(__half); - do { - oldval.load(pointer); - std::array next = oldval.as_two(); - next[offset] = __convert_float_to_halfint( - __convert_halfint_to_float(next[offset]) + float(rhs)); - newval = next; - } while (!__sync_bool_compare_and_swap( - (int32_t *)pointer, oldval.as_one(), newval.as_one())); -#endif -#endif -} - -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, - RHS rhs2) { - rhs1 = rhs1 + rhs2; -} - -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, - RHS rhs2) { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) -#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) - atomicAdd(&rhs1, rhs2); -#else - // 16-bit atomics are not supported prior to volta - // 32-bit GPU atomics need 4 byte alignment - const uintptr_t unaligned = reinterpret_cast(&rhs1); - unsigned const offset = unaligned % sizeof(unsigned int); - const uintptr_t aligned = unaligned - offset; - unsigned int *ptr = reinterpret_cast(aligned); - RHS newval = rhs1, oldval, other; - if (offset == 0) { - other = *((&rhs1) + 1); - do { - oldval = newval; - newval = newval + rhs2; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); - newval = __uint2lohalf(result); - other = __uint2hihalf(result); - } while (oldval != newval); - } else { - other = *((&rhs1) - 1); - do { - oldval = newval; - newval = newval + rhs2; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); - other = __uint2lohalf(result); - newval = __uint2hihalf(result); - } while (oldval != newval); - } -#endif -#else -#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) - std::atomic_ref atomic(rhs1); - RHS oldval = atomic.load(); - RHS newval; - do { - newval = oldval + rhs2; - } while (!atomic.compare_exchange_weak(oldval, newval)); -#else - // No atomic floating point operations so use compare and swap - TypePunning::Alias> oldval, newval; - TypePunning::AlignedPointer pointer((void *)&rhs1); - unsigned const offset = pointer.offset() / sizeof(__half); - do { - oldval.load(pointer); - std::array next = oldval.as_two(); - next[offset] = __convert_float_to_halfint( - __convert_halfint_to_float(next[offset]) + float(rhs2)); - newval = next; - } while (!__sync_bool_compare_and_swap( - (int32_t *)pointer, oldval.as_one(), newval.as_one())); -#endif -#endif -} - template <> __LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, RHS rhs) { @@ -670,7 +523,7 @@ __LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, #endif } -void Realm::Event register_all_redops(Realm::Runtime rt) { +void register_all_redops(Realm::Runtime rt) { // Registration is synchronous, so no need to capture events here rt.register_reduction>( get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_BOOL_REDOP_ID)); @@ -678,8 +531,6 @@ void Realm::Event register_all_redops(Realm::Runtime rt) { get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT32_REDOP_ID)); rt.register_reduction>( get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT64_REDOP_ID)); - rt.register_reduction>( - get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_HALF_REDOP_ID)); rt.register_reduction>( get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_FLOAT_REDOP_ID)); rt.register_reduction>( diff --git a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc index 702ddd5e97..f31769419f 100644 --- a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc +++ b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc @@ -1,9 +1,9 @@ #include "realm-execution/redops/redop_id_t.h" +#include "utils/exception.h" namespace FlexFlow { -Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { - +redop_id_t get_sum_redop_id_for_data_type(DataType dtype) { switch (dtype) { case DataType::BOOL: return redop_id_t::SUM_BOOL_REDOP_ID; @@ -11,8 +11,6 @@ Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { return redop_id_t::SUM_INT32_REDOP_ID; case DataType::INT64: return redop_id_t::SUM_INT64_REDOP_ID; - case DataType::HALF: - return redop_id_t::SUM_HALF_REDOP_ID; case DataType::FLOAT: return redop_id_t::SUM_FLOAT_REDOP_ID; case DataType::DOUBLE: @@ -22,11 +20,9 @@ Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { } } -Realm::Processor::ReductionOpID +Realm::ReductionOpID get_realm_reduction_op_id_for_redop_id(redop_id_t redop_id) { - return static_cast(redop_id); -} - + return static_cast(redop_id); } } // namespace FlexFlow From 5ef0b070ad9cf21cbae6346e8294074fb81e74ad Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 14:45:34 -0700 Subject: [PATCH 12/19] Split reduction from copy and put back device op state init code. --- .../include/realm-execution/realm_context.h | 29 ++-- ...uted_per_device_op_state_initialization.cc | 6 +- .../src/realm-execution/pcg_instance.cc | 19 ++- .../src/realm-execution/realm_context.cc | 127 ++++++++++++------ 4 files changed, 112 insertions(+), 69 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index eab42d0d79..5b76d52e2c 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include #include @@ -65,16 +66,24 @@ struct RealmContext { /** \name Data movement and reduction */ ///\{ - Realm::Event - issue_copy(ParallelTensorShape const &src_shape, - Realm::RegionInstance src_inst, - ParallelTensorShape const &dst_shape, - Realm::RegionInstance dst_inst, - Realm::ProfilingRequestSet const &requests, - Realm::Event wait_on = Realm::Event::NO_EVENT, - int priority = 0, - std::optional redop_id = std::nullopt, - bool exclusive = false); + Realm::Event issue_copy(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); + + Realm::Event issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); ///\} /** \name Instance management */ diff --git a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc index e7d8647b12..1d517a8fe4 100644 --- a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc @@ -31,7 +31,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( std::unordered_map *> device_state_map; - std::vector completion_events; for (DynamicNodeInvocation const &invocation : dg.invocations) { Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); @@ -57,7 +56,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( precondition); if (completion_event.has_value()) { - completion_events.push_back(completion_event.value()); device_state_map.insert(std::pair{invocation, device_state_ptr}); } else { // Task doesn't require initialization, clean up and don't store result @@ -65,9 +63,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( } } - // wait for all init tasks — direct write to *result_ptr happens - // before each init task event fires so result is ready after this - Realm::Event::merge_events(completion_events).wait(); + ctx.get_outstanding_events().wait(); auto deref = [](DeviceSpecificPtr *const &p) { return *p; }; std::unordered_map> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 1ac3821142..aa67110127 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -228,12 +228,11 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance dst_inst = tensor_instance_backing.backing.at(input_grad).first; - Realm::ReductionOpID redop_id = - get_realm_reduction_op_id_for_redop_id(get_sum_redop_id_for_data_type( - assert_unwrap(output_grad.parallel_tensor_shape).data_type)); + redop_id_t redop_id = get_sum_redop_id_for_data_type( + assert_unwrap(output_grad.parallel_tensor_shape).data_type); // chain reductions sequentially to avoid write races on dst - Realm::Event e = precondition; + Realm::Event result = precondition; for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) { DynamicValueAttrs replica_key = output_grad; replica_key.mapping = @@ -243,18 +242,18 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance src_inst = tensor_instance_backing.backing.at(replica_key).first; - e = ctx.issue_copy( + result = ctx.issue_reduction( /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), /*src_inst=*/src_inst, /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), /*dst_inst=*/dst_inst, - /*requests=*/Realm::ProfilingRequestSet{}, - /*wait_on=*/e, - /*priority=*/0, /*redop_id=*/redop_id, - /*exlusive=*/false); + /*is_fold=*/false, + /*exlusive=*/false, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/result); } - return e; + return result; }; TrainingOperationAttrs op_attrs = diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index a4669bf43e..36dd7c71cc 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -7,6 +7,7 @@ #include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" #include "realm-execution/realm_allocator.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/containers/contains_key.h" @@ -154,6 +155,46 @@ static Realm::IndexSpace ispace_from_dims(TensorDims const &dims) { return Realm::IndexSpace{rect}; } +[[nodiscard]] static Realm::Event + issue_copy_for_field(TensorDims const &dims, + Realm::CopySrcDstField const &src_field, + Realm::CopySrcDstField const &dst_field, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + switch (dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + return ispace_from_dims<1>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 2 + case 2: + return ispace_from_dims<2>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 3 + case 3: + return ispace_from_dims<3>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 4 + case 4: + return ispace_from_dims<4>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 5 + case 5: + return ispace_from_dims<5>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", + dims.ff_ordered.num_dims()); + break; + } +} + Realm::Event RealmContext::issue_copy(ParallelTensorShape const &src_shape, Realm::RegionInstance src_inst, @@ -161,9 +202,7 @@ Realm::Event Realm::RegionInstance dst_inst, Realm::ProfilingRequestSet const &requests, Realm::Event wait_on, - int priority, - std::optional redop_id, - bool exclusive) { + int priority) { TensorShape src_piece_shape = get_piece_shape(src_shape); TensorShape dst_piece_shape = get_piece_shape(dst_shape); ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match @@ -185,48 +224,48 @@ Realm::Event size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), /*subfield_offset=*/0); - // set reduction op on dst field if provided - if (redop_id.has_value()) { - dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive); - } + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); + this->outstanding_events.push_back(result); + return result; +} - Realm::Event result; - switch (src_piece_shape.dims.ff_ordered.num_dims()) { -#if REALM_MAX_DIM >= 1 - case 1: - result = ispace_from_dims<1>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 2 - case 2: - result = ispace_from_dims<2>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 3 - case 3: - result = ispace_from_dims<3>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 4 - case 4: - result = ispace_from_dims<4>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 5 - case 5: - result = ispace_from_dims<5>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif - default: - PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", - src_piece_shape.dims.ff_ordered.num_dims()); - break; - } +Realm::Event + RealmContext::issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + TensorShape src_piece_shape = get_piece_shape(src_shape); + TensorShape dst_piece_shape = get_piece_shape(dst_shape); + ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match + + Realm::CopySrcDstField src_field; + src_field.set_field( + /*inst=*/src_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + Realm::CopySrcDstField dst_field; + dst_field.set_field( + /*inst=*/dst_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + dst_field.set_redop( + get_realm_reduction_op_id_for_redop_id(redop_id), is_fold, exclusive); + + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); this->outstanding_events.push_back(result); return result; } From b73751727c77c92358624fb9bd9997e3f77e73c3 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 14:48:55 -0700 Subject: [PATCH 13/19] Replicate is not a task, don't represent it as one. --- .../include/realm-execution/tasks/task_id_t.dtg.toml | 9 --------- .../src/realm-execution/tasks/realm_task_registry.cc | 3 --- 2 files changed, 12 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index b1e5e07e28..b0bcc23b4d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -327,15 +327,6 @@ name = "COMBINE_FWD_TASK_ID" [[values]] name = "COMBINE_BWD_TASK_ID" -[[values]] -name = "REPLICATE_INIT_TASK_ID" - -[[values]] -name = "REPLICATE_FWD_TASK_ID" - -[[values]] -name = "REPLICATE_BWD_TASK_ID" - [[values]] name = "REDUCTION_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..dfdfe72ce0 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -49,7 +49,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_INIT_TASK_ID, task_id_t::REDUCTION_INIT_TASK_ID, task_id_t::REPARTITION_INIT_TASK_ID, - task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, }; @@ -86,7 +85,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_FWD_TASK_ID, task_id_t::REDUCTION_FWD_TASK_ID, task_id_t::REPARTITION_FWD_TASK_ID, - task_id_t::REPLICATE_FWD_TASK_ID, task_id_t::RESHAPE_FWD_TASK_ID, task_id_t::REVERSE_FWD_TASK_ID, task_id_t::SOFTMAX_FWD_TASK_ID, @@ -115,7 +113,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_BWD_TASK_ID, task_id_t::REDUCTION_BWD_TASK_ID, task_id_t::REPARTITION_BWD_TASK_ID, - task_id_t::REPLICATE_BWD_TASK_ID, task_id_t::RESHAPE_BWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID, task_id_t::SOFTMAX_BWD_TASK_ID, From c536cc96bfd5b4b88739dcbcfcab9f00a96a7fed Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 14:49:47 -0700 Subject: [PATCH 14/19] Put back the per device op state return code path. --- .../tasks/impl/per_device_op_state_init_task.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc index 0ea51810e4..753fccf74b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc @@ -66,17 +66,11 @@ void per_device_op_state_init_task_body(void const *args, result_state, ctx.get_current_device_idx())}; DeviceSpecificPtr result_device_specific{ ctx.get_current_device_idx(), result_state_ptr}; - - // replace spawn_per_device_op_state_init_return_task with: - // NOTE: SM/TODO: direct write assumes single-node shared address space - // For multi-node, replace with UserEvent trigger pattern - *task_args.origin_result_ptr = result_device_specific; - - // spawn_per_device_op_state_init_return_task(ctx, - // task_args.origin_proc, - // result_device_specific, - // task_args.origin_result_ptr, - // Realm::Event::NO_EVENT); + spawn_per_device_op_state_init_return_task(ctx, + task_args.origin_proc, + result_device_specific, + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); } std::optional spawn_per_device_op_state_init_task( From 48673bc64d390ea29c412f3d78661e4bd80b7b7e Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 28 May 2026 03:50:13 -0700 Subject: [PATCH 15/19] Changes for breaking up the replicate PR testing code --- .../pcg_task_graph.dtg.toml | 18 +- .../abstracted_single_tensor_movement.cc | 2 +- ...racted_tensor_set_movement_across_split.cc | 11 +- ...substitution_and_update_machine_mapping.cc | 2 +- .../get_optimal_machine_mapping.cc | 4 +- .../get_tensor_set_movement_across_split.cc | 1 - .../machine_mapping/machine_mapping.cc | 8 +- .../machine_mapping_constraints.cc | 12 +- .../compiler/machine_mapping/machine_view.cc | 4 +- ...get_optimal_machine_mapping_with_memory.cc | 4 +- ...el_layer_guid_oblivious_machine_mapping.cc | 10 +- .../pcg/pcg_binary_sp_decomposition.cc | 2 +- .../task_graph_simulator/pcg_task_graph.cc | 17 +- .../task_graph_simulator/task_simulator.cc | 2 +- .../unity_algorithm/unity_algorithm.cc | 2 +- .../src/local-execution/tensor_allocation.cc | 2 +- .../src/op-attrs/get_incoming_tensor_roles.cc | 2 +- .../op-attrs/parallel_tensor_dim_degrees.cc | 6 +- .../parallel_tensor_space_coordinate.cc | 4 +- lib/op-attrs/src/op-attrs/shape_inference.cc | 2 +- .../mapped_operator_task_group.h | 6 +- .../mapped_parallel_computation_graph.h | 4 + .../mapped_parallel_layer_info.dtg.toml | 31 + ...ed_parallel_layer_invocation_info.dtg.toml | 33 + .../mapped_parallel_layer_invocation_info.h | 17 + .../parallel_computation_graph.h | 8 + .../parallel_layer_info.dtg.toml | 26 + .../parallel_layer_invocation_info.dtg.toml | 33 + .../parallel_tensor_info.dtg.toml | 26 + lib/pcg/src/pcg/computation_graph.cc | 2 +- lib/pcg/src/pcg/computation_graph_builder.cc | 8 +- .../mapped_operator_task_group.cc | 38 +- .../mapped_parallel_computation_graph.cc | 38 + .../mapped_parallel_layer_invocation_info.cc | 22 + .../parallel_computation_graph.cc | 48 +- .../parallel_computation_graph_builder.cc | 8 +- .../mapped_parallel_computation_graph.cc | 64 +- .../parallel_computation_graph_builder.cc | 4 +- .../realm-execution/instance_allocation.cc | 2 +- .../src/realm-execution/pcg_instance.cc | 4 +- .../serializable_tensor_instance_backing.cc | 12 +- .../src/realm-execution/test_op_replicate.cc | 2 + lib/runtime/src/parallel_tensor_uses.cc | 2 +- lib/runtime/src/tensor_uses.cc | 2 +- .../apply_substitution/apply_substitution.cc | 4 +- .../perform_shape_inference.cc | 4 +- .../src/substitutions/pcg_pattern_match.cc | 2 +- .../src/substitutions/substitution_builder.cc | 2 +- .../unlabelled/find_pattern_matches.cc | 7 +- .../unlabelled/pattern_matching.cc | 2 +- .../task-spec/dynamic_graph/copy_insertion.h | 5 + .../dynamic_graph/dynamic_node_invocation.h | 20 + ...mic_node_invocation_sharding_info.dtg.toml | 30 + .../dynamic_open_dataflow_graph.h | 6 + .../dynamic_value_attrs.dtg.toml | 29 +- .../dynamic_graph/dynamic_value_attrs.h | 4 + ...dynamic_value_attrs_sharding_info.dtg.toml | 27 + .../include/task-spec/dynamic_graph/index.dox | 1 + ...amic_open_dataflow_graph_from_mapped_pcg.h | 6 + .../task-spec/dynamic_graph/pass_expansion.h | 5 + .../serializable_dynamic_value_attrs.dtg.toml | 4 +- .../task-spec/dynamic_graph/shard_expansion.h | 35 +- .../dynamic_graph/update_insertion.h | 9 + .../task-spec/dynamic_graph/copy_insertion.cc | 125 ++- .../dynamic_graph/dynamic_node_invocation.cc | 35 + .../dynamic_open_dataflow_graph.cc | 13 + .../dynamic_graph/dynamic_value_attrs.cc | 13 + ...ake_dynamic_open_dataflow_graph_from_cg.cc | 2 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 213 +---- .../task-spec/dynamic_graph/pass_expansion.cc | 18 + .../dynamic_graph/shard_expansion.cc | 347 ++++--- .../task-spec/dynamic_graph/copy_insertion.cc | 873 +++++++++++------- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 396 ++++++++ .../dynamic_graph/shard_expansion.cc | 463 +++++----- .../archetypes/jsonable_ordered_value_type.h | 91 ++ .../{filter_keys.h => bidict_filter_keys.h} | 6 +- ...filter_values.h => bidict_filter_values.h} | 6 +- ...filtrans_keys.h => bidict_filtrans_keys.h} | 6 +- ...rans_values.h => bidict_filtrans_values.h} | 6 +- ...red_set_of.h => bidict_unordered_set_of.h} | 6 +- .../utils/bidict/algorithms/transform_keys.h | 2 +- .../bidict/algorithms/transform_values.h | 2 +- .../unstructured_relation_from_bidict.h | 4 +- lib/utils/include/utils/bidict/bidict.h | 42 +- lib/utils/include/utils/containers/all_of.h | 8 +- .../containers/binary_merge_disjoint_maps.h | 4 +- .../utils/containers/binary_merge_maps_with.h | 10 +- lib/utils/include/utils/containers/filter.h | 7 + .../include/utils/containers/generate_map.h | 6 +- .../utils/containers/generate_unordered_map.h | 28 + .../utils/containers/get_all_assignments.h | 4 +- lib/utils/include/utils/containers/index.dox | 2 +- .../include/utils/containers/is_submapeq_of.h | 4 +- lib/utils/include/utils/containers/items.h | 4 +- lib/utils/include/utils/containers/keys.h | 10 +- .../include/utils/containers/lookup_in_map.h | 16 +- .../utils/containers/map_from_unordered.h | 18 + .../include/utils/containers/map_keys2.h | 2 +- .../utils/containers/map_keys_and_values.h | 3 +- .../include/utils/containers/map_values.h | 13 + .../include/utils/containers/require_all_of.h | 33 + .../utils/containers/require_only_key.h | 9 + .../include/utils/containers/require_same.h | 2 +- .../include/utils/containers/transform.h | 14 +- .../utils/containers/unordered_items.h | 17 + .../include/utils/containers/unordered_keys.h | 30 + .../utils/containers/unordered_map_from_map.h | 18 + .../utils/containers/zip_values_strict.h | 8 +- .../utils/containers/zip_values_strict_with.h | 8 +- lib/utils/include/utils/fmt/map.h | 4 +- .../unordered_set_kwarg_dataflow_graph.h | 4 +- ...ordered_set_labelled_open_dataflow_graph.h | 17 +- ...d_set_labelled_open_kwarg_dataflow_graph.h | 23 +- .../unordered_set_open_kwarg_dataflow_graph.h | 4 +- .../algorithms/get_incoming_slots_for_node.h | 4 +- .../algorithms/get_outgoing_slots_for_node.h | 4 +- .../algorithms/get_graph_data.h | 4 +- .../algorithms/permute_input_ids.h | 4 +- .../algorithms/permute_node_ids.h | 6 +- .../algorithms/rewrite_labels.h | 6 +- ..._labelled_open_kwarg_dataflow_graph_data.h | 6 +- .../labelled_open_kwarg_dataflow_graph_data.h | 4 +- ...lled_open_kwarg_dataflow_graph_input_ids.h | 6 +- ...elled_open_kwarg_dataflow_graph_node_ids.h | 6 +- ...abelled_open_kwarg_dataflow_graph_labels.h | 6 +- .../graph/series_parallel/get_ancestors.h | 1 + .../non_normal_parallel_split.dtg.toml | 9 +- .../non_normal_series_split.dtg.toml | 2 + .../non_normal_sp_decomposition.dtg.toml | 1 + .../series_parallel/parallel_split.dtg.toml | 9 +- .../series_parallel_decomposition.dtg.toml | 1 + .../series_parallel/series_split.dtg.toml | 2 + .../include/utils/json/check_is_jsonable.h | 4 +- .../include/utils/many_to_one/many_to_one.h | 29 +- .../many_to_one_from_unstructured_relation.h | 20 - .../unstructured_relation_from_many_to_one.h | 17 - .../include/utils/nonempty_set/nonempty_set.h | 136 +++ .../include/utils/one_to_many/one_to_many.h | 81 +- ...d_relation.h => one_to_many_filter_keys.h} | 15 +- .../one_to_many/one_to_many_filter_values.h | 22 + .../one_to_many_transform_values.h | 6 +- .../unstructured_relation_from_one_to_many.h | 21 - lib/utils/include/utils/orthotope/dim_coord.h | 8 +- .../include/utils/orthotope/dim_domain.h | 4 +- .../utils/orthotope/dim_projection.dtg.toml | 1 + .../utils/orthotope/down_projection.dtg.toml | 1 + .../include/utils/orthotope/down_projection.h | 2 +- .../utils/orthotope/eq_projection.dtg.toml | 1 + .../utils/orthotope/minimal_dim_domain.h | 8 +- .../utils/orthotope/up_projection.dtg.toml | 1 + .../include/utils/orthotope/up_projection.h | 4 +- .../archetypes/jsonable_ordered_value_type.cc | 7 + .../{filter_keys.cc => bidict_filter_keys.cc} | 4 +- ...lter_values.cc => bidict_filter_values.cc} | 4 +- ...ltrans_keys.cc => bidict_filtrans_keys.cc} | 4 +- ...ns_values.cc => bidict_filtrans_values.cc} | 4 +- .../algorithms/bidict_unordered_set_of.cc | 11 + .../bidict/algorithms/unordered_set_of.cc | 11 - lib/utils/src/utils/cli/cli_parse.cc | 4 +- lib/utils/src/utils/containers/filter.cc | 44 + .../src/utils/containers/generate_map.cc | 12 + .../containers/generate_unordered_map.cc | 13 + lib/utils/src/utils/containers/group_by.cc | 4 +- .../src/utils/containers/is_submapeq_of.cc | 11 + lib/utils/src/utils/containers/items.cc | 13 + lib/utils/src/utils/containers/keys.cc | 12 + .../utils/containers/map_from_unordered.cc | 13 + lib/utils/src/utils/containers/map_values.cc | 8 + .../src/utils/containers/require_all_of.cc | 34 + .../src/utils/containers/require_only_key.cc | 5 + lib/utils/src/utils/containers/transform.cc | 46 + .../src/utils/containers/unordered_items.cc | 15 + .../src/utils/containers/unordered_keys.cc | 13 + .../containers/unordered_map_from_map.cc | 12 + .../algorithms/dataflow_graph_as_dot.cc | 1 - .../digraph/algorithms/get_dominators_map.cc | 4 +- .../algorithms/get_imm_dominators_map.cc | 8 +- .../algorithms/get_imm_post_dominator.cc | 4 +- .../digraph/algorithms/get_incoming_edges.cc | 24 +- .../digraph/algorithms/get_outgoing_edges.cc | 25 +- .../graph/digraph/algorithms/is_acyclic.cc | 4 +- .../graph/instances/adjacency_multidigraph.cc | 18 +- .../algorithms/get_incoming_edges.cc | 7 +- .../get_multidiedge_to_diedge_map.cc | 4 +- .../algorithms/get_outgoing_edges.cc | 7 +- .../algorithms/get_incoming_edges.cc | 4 +- .../balanced_binary_sp_tree_from_nary.cc | 5 +- .../non_normal_sp_decomposition.cc | 10 +- .../normalize_sp_decomposition.cc | 5 +- .../series_parallel_decomposition.cc | 7 +- .../series_parallel_metrics.cc | 2 + .../sp_ization/escribano_algo.cc | 4 +- .../sp_ization/flexible_algo.cc | 6 +- .../sp_ization/naive_stratum_sync.cc | 4 +- .../series_parallel/sp_ization/node_role.cc | 4 +- .../many_to_one/exhaustive_relational_join.cc | 8 +- .../utils/many_to_one/invert_many_to_one.cc | 6 +- .../src/utils/many_to_one/many_to_one.cc | 12 +- .../many_to_one_from_unstructured_relation.cc | 12 - .../unstructured_relation_from_many_to_one.cc | 12 - .../src/utils/nonempty_set/nonempty_set.cc | 23 + .../one_to_many/exhaustive_relational_join.cc | 8 +- .../utils/one_to_many/invert_one_to_many.cc | 6 +- .../src/utils/one_to_many/one_to_many.cc | 20 +- .../one_to_many/one_to_many_filter_keys.cc | 12 + .../one_to_many/one_to_many_filter_values.cc | 13 + .../one_to_many/one_to_many_from_bidict.cc | 6 +- .../one_to_many_from_l_to_r_mapping.cc | 6 +- .../one_to_many_from_unstructured_relation.cc | 12 - .../one_to_many_transform_values.cc | 8 +- .../unstructured_relation_from_one_to_many.cc | 12 - .../src/utils/orthotope/dim_domain_mapping.cc | 14 +- .../src/utils/orthotope/dim_projection.cc | 11 +- .../src/utils/orthotope/down_projection.cc | 12 +- .../src/utils/orthotope/eq_projection.cc | 12 +- .../orthotope/minimal_dim_domain_mapping.cc | 14 +- .../src/utils/orthotope/up_projection.cc | 7 +- .../{filter_keys.cc => bidict_filter_keys.cc} | 7 +- ...lter_values.cc => bidict_filter_values.cc} | 6 +- ...ltrans_keys.cc => bidict_filtrans_keys.cc} | 6 +- ...ns_values.cc => bidict_filtrans_values.cc} | 8 +- ...d_set_of.cc => bidict_unordered_set_of.cc} | 2 +- lib/utils/test/src/utils/bidict/bidict.cc | 2 + .../test/src/utils/containers/enumerate.cc | 4 +- lib/utils/test/src/utils/containers/keys.cc | 8 +- .../algorithms/get_imm_post_dominators_map.cc | 1 - .../sp_ization/work_duplicating_sp_ization.cc | 4 +- .../test/src/utils/many_to_one/many_to_one.cc | 79 ++ .../many_to_one_from_unstructured_relation.cc | 54 -- .../unstructured_relation_from_many_to_one.cc | 25 - .../test/src/utils/one_to_many/one_to_many.cc | 107 ++- .../one_to_many_from_unstructured_relation.cc | 53 -- .../unstructured_relation_from_one_to_many.cc | 25 - 233 files changed, 3602 insertions(+), 1714 deletions(-) create mode 100644 lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml create mode 100644 lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml create mode 100644 lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml create mode 100644 lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc create mode 100644 lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc create mode 100644 lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h rename lib/utils/include/utils/bidict/algorithms/{filter_keys.h => bidict_filter_keys.h} (53%) rename lib/utils/include/utils/bidict/algorithms/{filter_values.h => bidict_filter_values.h} (53%) rename lib/utils/include/utils/bidict/algorithms/{filtrans_keys.h => bidict_filtrans_keys.h} (64%) rename lib/utils/include/utils/bidict/algorithms/{filtrans_values.h => bidict_filtrans_values.h} (63%) rename lib/utils/include/utils/bidict/algorithms/{unordered_set_of.h => bidict_unordered_set_of.h} (51%) create mode 100644 lib/utils/include/utils/containers/generate_unordered_map.h create mode 100644 lib/utils/include/utils/containers/map_from_unordered.h create mode 100644 lib/utils/include/utils/containers/require_all_of.h create mode 100644 lib/utils/include/utils/containers/unordered_items.h create mode 100644 lib/utils/include/utils/containers/unordered_keys.h create mode 100644 lib/utils/include/utils/containers/unordered_map_from_map.h delete mode 100644 lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h delete mode 100644 lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h create mode 100644 lib/utils/include/utils/nonempty_set/nonempty_set.h rename lib/utils/include/utils/one_to_many/{one_to_many_from_unstructured_relation.h => one_to_many_filter_keys.h} (50%) create mode 100644 lib/utils/include/utils/one_to_many/one_to_many_filter_values.h delete mode 100644 lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h create mode 100644 lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc rename lib/utils/src/utils/bidict/algorithms/{filter_keys.cc => bidict_filter_keys.cc} (58%) rename lib/utils/src/utils/bidict/algorithms/{filter_values.cc => bidict_filter_values.cc} (57%) rename lib/utils/src/utils/bidict/algorithms/{filtrans_keys.cc => bidict_filtrans_keys.cc} (61%) rename lib/utils/src/utils/bidict/algorithms/{filtrans_values.cc => bidict_filtrans_values.cc} (61%) create mode 100644 lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc delete mode 100644 lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc create mode 100644 lib/utils/src/utils/containers/generate_unordered_map.cc create mode 100644 lib/utils/src/utils/containers/map_from_unordered.cc create mode 100644 lib/utils/src/utils/containers/require_all_of.cc create mode 100644 lib/utils/src/utils/containers/unordered_items.cc create mode 100644 lib/utils/src/utils/containers/unordered_keys.cc create mode 100644 lib/utils/src/utils/containers/unordered_map_from_map.cc delete mode 100644 lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc delete mode 100644 lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc create mode 100644 lib/utils/src/utils/nonempty_set/nonempty_set.cc create mode 100644 lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc create mode 100644 lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc delete mode 100644 lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc delete mode 100644 lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc rename lib/utils/test/src/utils/bidict/algorithms/{filter_keys.cc => bidict_filter_keys.cc} (64%) rename lib/utils/test/src/utils/bidict/algorithms/{filter_values.cc => bidict_filter_values.cc} (61%) rename lib/utils/test/src/utils/bidict/algorithms/{filtrans_keys.cc => bidict_filtrans_keys.cc} (75%) rename lib/utils/test/src/utils/bidict/algorithms/{filtrans_values.cc => bidict_filtrans_values.cc} (69%) rename lib/utils/test/src/utils/bidict/algorithms/{unordered_set_of.cc => bidict_unordered_set_of.cc} (77%) delete mode 100644 lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc delete mode 100644 lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc delete mode 100644 lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc delete mode 100644 lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml index 2c5b5f56fc..31b87feb31 100644 --- a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml +++ b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml @@ -7,19 +7,19 @@ features = [ includes = [ "utils/graph/digraph/digraph_view.h", - "utils/bidict/bidict.h", "compiler/task_graph_simulator/pcg_task.dtg.h", "pcg/device_id_t.dtg.h", + "utils/many_to_one/many_to_one.h", "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "", - "" + "", + "" ] src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h" + "utils/fmt/set.h", + "utils/hash/set.h", + "utils/fmt/map.h", + "utils/hash/map.h" ] [[fields]] @@ -28,8 +28,8 @@ type = "::FlexFlow::DiGraphView" [[fields]] name = "node_to_task" -type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::PCGTask>" +type = "::FlexFlow::ManyToOne<::FlexFlow::Node, ::FlexFlow::PCGTask>" [[fields]] name = "node_to_devices" -type = "std::unordered_map<::FlexFlow::Node, std::unordered_set<::FlexFlow::device_id_t>>" +type = "std::map<::FlexFlow::Node, std::set<::FlexFlow::device_id_t>>" diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc index 1aeb83d202..5f9300973f 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc @@ -15,7 +15,7 @@ std::unordered_set abstracted_single_tensor_movement_get_dst_layers( AbstractedSingleTensorMovement const &m) { return transform( - keys(m.edge_to_size), + unordered_keys(m.edge_to_size), [](AbstractedSingleTensorCommunicationEdge const &e) -> BinaryTreePath { return e.dst.operator_tree_path; }); diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 151008f65f..6ff261facd 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -10,10 +10,9 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include "utils/containers/binary_cartesian_product.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" #include "utils/containers/group_by.h" #include "utils/containers/map_from_pairs.h" @@ -46,7 +45,7 @@ AbstractedSingleTensorMovement get_abstracted_single_tensor_movement_along_edge( std::unordered_map single_comms = map_from_pairs(transform( - unordered_set_of(coord_mapping), + bidict_unordered_set_of(coord_mapping), [&](std::pair const & src_dst) -> std::pair { @@ -101,9 +100,9 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( }; return AbstractedTensorSetMovement{ - transform(edges_by_tensor.right_groups(), - [&](nonempty_unordered_set const - &edges) { + transform(unordered_set_of(edges_by_tensor.right_groups()), + [&](nonempty_set const &edges) + { return merge_abstracted_single_tensor_movements(transform( unordered_multiset_of(edges.unwrap_as_unordered_set()), to_abstracted_single_tensor_movement)); diff --git a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc index 7ccab2fac9..4e38750de3 100644 --- a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc @@ -59,7 +59,7 @@ SearchResult apply_substitution_and_update_machine_mapping( select_random(substituted_machine_views)); } - ASSERT(is_subseteq_of(keys(post_node_data), keys(machine_views))); + ASSERT(is_subseteq_of(unordered_keys(post_node_data), unordered_keys(machine_views))); std::unordered_map post_node_machine_views = diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 77e50740aa..48f3bd9eed 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -23,7 +23,7 @@ #include "utils/containers/contains.h" #include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/keys.h" #include "utils/containers/set_minus.h" @@ -103,7 +103,7 @@ MachineMappingResult set_minus(boundary_layers, get_constrained_layers(sub_constraints)); std::unordered_map> - allowed = generate_map( + allowed = generate_unordered_map( unconstrained_boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 7d1f28337c..f7dbdd1d05 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -7,7 +7,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "utils/containers/generate_map.h" #include "utils/containers/keys.h" #include "utils/containers/map_values.h" #include "utils/containers/sum.h" diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index a2307716ba..861912efef 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -8,7 +8,7 @@ #include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/binary_merge_disjoint_maps.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -20,7 +20,7 @@ MappedParallelComputationGraph get_parallel_layers(pcg); std::unordered_set mapped_layers = - keys(mapping.machine_views); + unordered_keys(mapping.machine_views); ASSERT(mapped_layers == pcg_layers); @@ -40,7 +40,7 @@ MappedParallelComputationGraph }; std::unordered_map - mapped_op_task_groups = generate_map(mapped_layers, mapping_for_layer); + mapped_op_task_groups = generate_unordered_map(mapped_layers, mapping_for_layer); return mapped_pcg_from_pcg_and_mapped_op_task_groups(pcg, mapped_op_task_groups); @@ -54,7 +54,7 @@ MachineMapping combine_disjoint_mappings(MachineMapping const &m1, } bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); + return are_disjoint(unordered_keys(m1.machine_views), unordered_keys(m2.machine_views)); } std::optional get_machine_mapping_from_machine_mapping_result( diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 8278d5511c..fe92c77def 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -3,7 +3,7 @@ #include "utils/containers/filter_values.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/keys.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" @@ -14,7 +14,7 @@ namespace FlexFlow { MachineMappingConstraints get_unconstrained_solution_for_layers( std::unordered_set const &layers) { return MachineMappingConstraints{ - generate_map(layers, + generate_unordered_map(layers, [](BinaryTreePath const &) -> std::optional { return std::nullopt; }), @@ -24,7 +24,7 @@ MachineMappingConstraints get_unconstrained_solution_for_layers( std::unordered_set get_unconstrained_layers(MachineMappingConstraints const &constraints) { - return keys(filter_values( + return unordered_keys(filter_values( constraints.machine_views, [](std::optional const &mv) { return !mv.has_value(); })); } @@ -32,14 +32,14 @@ std::unordered_set std::unordered_set get_constrained_layers(MachineMappingConstraints const &constraints) { - return keys(filter_values( + return unordered_keys(filter_values( constraints.machine_views, [](std::optional const &mv) { return mv.has_value(); })); } std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution) { - return keys(partial_solution.machine_views); + return unordered_keys(partial_solution.machine_views); } std::optional get_machine_view_for_layer( @@ -103,7 +103,7 @@ MachineMappingConstraints with_additional_constraints( std::optional require_only_root(MachineMappingConstraints const &constraints) { - ASSERT(keys(constraints.machine_views) == + ASSERT(unordered_keys(constraints.machine_views) == std::unordered_set{binary_tree_root_path()}, fmt::format("require_only_root expected constraints to have only a " "single key (the root path), but received {}", diff --git a/lib/compiler/src/compiler/machine_mapping/machine_view.cc b/lib/compiler/src/compiler/machine_mapping/machine_view.cc index 090dec5845..7d00707b0d 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_view.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_view.cc @@ -221,8 +221,8 @@ static OperatorAtomicTaskShardBinding mappings = get_operator_to_ptensor_mappings(op_attrs, inputs_dim_degrees); std::unordered_map - ptensor_coords = generate_map( - keys(inputs_dim_degrees), + ptensor_coords = generate_unordered_map( + unordered_keys(inputs_dim_degrees), [&](TensorSlotName const &slot_name) -> ParallelTensorSpaceCoordinate { num_ptensor_shard_dims_t num_shard_dims = diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index a3f2009a60..dad91dd317 100644 --- a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -17,7 +17,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" @@ -84,7 +84,7 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( std::unordered_set const &boundary_layers) -> std::unordered_set { std::unordered_map> - allowed = generate_map( + allowed = generate_unordered_map( boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc index 6e2096afcc..ac39021f6f 100644 --- a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -45,7 +45,7 @@ std::unordered_map PCGBinarySPDecomposition const &decomposition, ParallelLayerGuidObliviousMachineMapping const &mapping) { std::unordered_set leaf_paths = require_same( - pcg_sp_tree_get_all_leaf_paths(decomposition), keys(mapping.raw_mapping)); + pcg_sp_tree_get_all_leaf_paths(decomposition), unordered_keys(mapping.raw_mapping)); std::unordered_map path_to_op_task_space_map = @@ -54,7 +54,7 @@ std::unordered_map return get_operator_task_space(pcg, l); }); - return generate_map( + return generate_unordered_map( leaf_paths, [&](BinaryTreePath const &p) -> MachineSpaceStencil { return MachineSpaceStencil{ /*operator_task_space=*/path_to_op_task_space_map.at(p), @@ -71,12 +71,12 @@ std::unordered_map> std::unordered_map tree_leaf_map = mm_problem_tree_get_path_to_leaf_map(tree); - std::unordered_set mapping_paths = keys(mapping.raw_mapping); - std::unordered_set tree_paths = keys(tree_leaf_map); + std::unordered_set mapping_paths = unordered_keys(mapping.raw_mapping); + std::unordered_set tree_paths = unordered_keys(tree_leaf_map); ASSERT(is_subseteq_of(mapping_paths, tree_paths)); - return generate_map( + return generate_unordered_map( tree_paths, [&](BinaryTreePath const &p) -> std::optional { if (!contains_key(mapping.raw_mapping, p)) { diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index cd8e634f2c..4d1c88d9eb 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -152,7 +152,7 @@ SPDecompositionTreeNodeType std::unordered_set pcg_sp_tree_get_all_leaf_paths(PCGBinarySPDecomposition const &tree) { - return keys(pcg_sp_tree_get_path_to_leaf_map(tree)); + return unordered_keys(pcg_sp_tree_get_path_to_leaf_map(tree)); } std::unordered_set diff --git a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc index d4d5a78d6a..b016b106e9 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc @@ -15,6 +15,7 @@ #include "utils/graph/instances/adjacency_digraph.h" #include #include +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -23,21 +24,21 @@ PCGTaskGraph MachineMapping const &machine_mapping, MachineComputeSpecification const &machine_spec) { DiGraph digraph = DiGraph::create(); - bidict node_to_task; + ManyToOne node_to_task; bidict node_to_layer; - std::unordered_map> node_to_devices; + std::map> node_to_devices; for (parallel_layer_guid_t const &layer : get_parallel_layers(pcg)) { MachineView mv = machine_mapping.machine_views.at(layer); RuntimeOnlyOpCostEstimateKey op_key = get_mapped_runtime_only_op_cost_estimate_key_for_layer(pcg, layer, mv); Node node = digraph.add_node(); - node_to_task.equate(node, PCGTask{op_key}); - node_to_layer.equate(node, layer); + node_to_task.insert({node, PCGTask{op_key}}); + node_to_layer.equate_strict(node, layer); node_to_devices[node] = - get_device_ids(get_operator_task_space(pcg, layer), - machine_mapping.machine_views.at(layer), - machine_spec); + set_of(get_device_ids(get_operator_task_space(pcg, layer), + machine_mapping.machine_views.at(layer), + machine_spec)); } for (ParallelComputationGraphEdge const &edge : get_edges(pcg)) { @@ -46,7 +47,7 @@ PCGTaskGraph TensorSetMovement movement = get_tensor_set_movement_from_pcg_edge(edge, pcg, src_mv, dst_mv); Node node = digraph.add_node(); - node_to_task.equate(node, PCGTask{movement}); + node_to_task.insert({node, PCGTask{movement}}); node_to_devices[node] = {}; Node src_node = node_to_layer.at_r(get_src_layer(edge)); Node dst_node = node_to_layer.at_r(get_dst_layer(edge)); diff --git a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc index bc528493a8..3fabfc3966 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc @@ -62,7 +62,7 @@ milliseconds_t task_simulator_estimate_forward_pass_time( std::unordered_set devices_occupied = set_union(transform(in_progress_tasks, get_devices)); - std::unordered_set required_devices = get_devices(task); + std::unordered_set required_devices = unordered_set_of(get_devices(task)); return intersection(devices_occupied, required_devices).empty(); }; diff --git a/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc index be8c7c4f98..05a049e98f 100644 --- a/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc +++ b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc @@ -20,7 +20,7 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/substitution.h" #include "substitutions/unity_substitution_set.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" #include "utils/optional.h" diff --git a/lib/local-execution/src/local-execution/tensor_allocation.cc b/lib/local-execution/src/local-execution/tensor_allocation.cc index bb2a1ba2a4..203345a9af 100644 --- a/lib/local-execution/src/local-execution/tensor_allocation.cc +++ b/lib/local-execution/src/local-execution/tensor_allocation.cc @@ -55,7 +55,7 @@ DynamicOpenDataflowGraph perform_tensor_allocation( Allocator &allocator) { ASSERT(no_tensors_are_allocated(g)); ASSERT(tensors_are_ready_for_allocation(g)); - for (DynamicValueAttrs const &v : keys(preallocated)) { + for (DynamicValueAttrs const &v : unordered_keys(preallocated)) { ASSERT(v.accessor == std::nullopt); } diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index eec9ae869c..89936d9b00 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -46,7 +46,7 @@ std::unordered_map }; }, [&](ConcatAttrs const &) { - return generate_map(get_variadic_inputs_slot_name_sequence(), + return generate_unordered_map(get_variadic_inputs_slot_name_sequence(), [](TensorSlotName) -> IncomingTensorRole { return IncomingTensorRole::INPUT; }); diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc index 51d7968033..83a7aded6a 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -8,7 +8,7 @@ #include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/filtrans.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" @@ -96,7 +96,7 @@ std::unordered_map }; std::unordered_map shard_dim_degrees = - generate_map(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) { + generate_unordered_map(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) { return degrees.shard_degrees.at(dim); }); @@ -131,7 +131,7 @@ DimDomain ParallelTensorDimDegrees const &dim_degrees) { return DimDomain{ - generate_map(get_parallel_tensor_dim_indices(dim_degrees), + generate_unordered_map(get_parallel_tensor_dim_indices(dim_degrees), [&](parallel_tensor_dim_idx_t idx) { return get_degree_for_parallel_tensor_dim_idx(dim_degrees, idx); diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc index 0c6e157697..79d765b02c 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc @@ -3,7 +3,7 @@ #include "op-attrs/parallel_tensor_dim_idx_t.h" #include "utils/containers/contains_key.h" #include "utils/containers/filtermap_keys.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/unordered_set_of.h" #include "utils/nonnegative_int/num_elements.h" @@ -73,7 +73,7 @@ DimCoord dim_coord_from_parallel_tensor_space_coord( ParallelTensorSpaceCoordinate const &coord) { return DimCoord{ - generate_map(get_dim_idxs_in_ptensor_space_coord(coord), + generate_unordered_map(get_dim_idxs_in_ptensor_space_coord(coord), [&](parallel_tensor_dim_idx_t idx) { return ptensor_coord_component_for_ptensor_dim_idx(coord, idx); diff --git a/lib/op-attrs/src/op-attrs/shape_inference.cc b/lib/op-attrs/src/op-attrs/shape_inference.cc index a3f8066dee..38e116bfd8 100644 --- a/lib/op-attrs/src/op-attrs/shape_inference.cc +++ b/lib/op-attrs/src/op-attrs/shape_inference.cc @@ -52,7 +52,7 @@ static std::vector std::vector expected_slots = slice(slots, 0, v_num_slots.unwrap_nonnegative()); - ASSERT(unordered_set_of(expected_slots) == keys(v)); + ASSERT(unordered_set_of(expected_slots) == unordered_keys(v)); return transform(expected_slots, [&](TensorSlotName const &slot_name) { return v.at(slot_name); diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h index aded1eb657..41aca802e7 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h @@ -7,6 +7,7 @@ #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "utils/bidict/bidict.h" #include +#include "utils/one_to_many/one_to_many.h" namespace FlexFlow { @@ -38,10 +39,13 @@ struct MappedOperatorTaskGroup { friend struct ::std::hash; }; -bidict +OneToMany get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &, TensorSlotName const &); +std::set get_slot_names_for_task_group(MappedOperatorTaskGroup const &); + + nlohmann::json mapped_operator_task_group_as_dot_json(MappedOperatorTaskGroup const &); diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index a2afdb7914..2e789e14c9 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -3,12 +3,16 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.h" namespace FlexFlow { std::unordered_set mpcg_get_parallel_layers(MappedParallelComputationGraph const &); +std::set + mpcg_get_invocation_set(MappedParallelComputationGraph const &); + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &, parallel_layer_guid_t); diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml new file mode 100644 index 0000000000..056cbecdde --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "MappedParallelLayerInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", +] + +src_includes = [ +] + +[[fields]] +name = "guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "attrs" +type = "::FlexFlow::ParallelLayerAttrs" + +[[fields]] +name = "mapping" +type = "::FlexFlow::MappedOperatorTaskGroup" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml new file mode 100644 index 0000000000..b28884d551 --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "MappedParallelLayerInvocationInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_info.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/map.h", + "utils/hash/map.h", +] + +[[fields]] +name = "incoming" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" + +[[fields]] +name = "layer_info" +type = "::FlexFlow::MappedParallelLayerInfo" + +[[fields]] +name = "outgoing" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h new file mode 100644 index 0000000000..dcda3a977b --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_LAYER_INVOCATION_INFO_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_LAYER_INVOCATION_INFO_H + +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" + +namespace FlexFlow { + +MappedParallelLayerInvocationInfo + mapped_parallel_layer_invocation_info_from_pcg_invocation_and_mapping( + ParallelLayerInvocationInfo const &, + MappedOperatorTaskGroup const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 9764e40627..7c5a825420 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -12,6 +12,7 @@ #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" #include +#include "pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.h" namespace FlexFlow { @@ -38,6 +39,13 @@ ParallelLayerAddedResult OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer); +std::set + pcg_get_invocation_info_set(ParallelComputationGraph const &); + +ParallelLayerInvocationInfo + pcg_get_invocation_info_for_layer(ParallelComputationGraph const &, + parallel_layer_guid_t); + std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &src, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml new file mode 100644 index 0000000000..67107cdbb1 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelLayerInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "attrs" +type = "::FlexFlow::ParallelLayerAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml new file mode 100644 index 0000000000..6555752bcf --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "ParallelLayerInvocationInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_info.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_info.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/map.h", + "utils/hash/map.h", +] + +[[fields]] +name = "incoming" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" + +[[fields]] +name = "layer_info" +type = "::FlexFlow::ParallelLayerInfo" + +[[fields]] +name = "outgoing" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml new file mode 100644 index 0000000000..09f53d6954 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelTensorInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "guid" +type = "::FlexFlow::parallel_tensor_guid_t" + +[[fields]] +name = "attrs" +type = "::FlexFlow::ParallelTensorAttrs" diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 56bfb98856..35ba0747f0 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -196,7 +196,7 @@ static std::unordered_map ASSERT(incoming_tensors.size() == incoming_slot_roles.size()); std::unordered_set slots_with_desired_role = - keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { + unordered_keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { return role == desired_role; })); diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index b687aa11b6..40e72aee9d 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -115,10 +115,10 @@ static void check_incoming_tensor_roles( set_union(input_slots, weight_slots)); std::unordered_map current = binary_merge_disjoint_maps( - generate_map( + generate_unordered_map( input_slots, [](TensorSlotName) { return IncomingTensorRole::INPUT; }), - generate_map(weight_slots, [](TensorSlotName) { + generate_unordered_map(weight_slots, [](TensorSlotName) { return IncomingTensorRole::WEIGHT; })); @@ -134,8 +134,8 @@ std::unordered_map &weight_initializers, std::optional> const &outputs) { - ASSERT(are_disjoint(keys(inputs), keys(weight_initializers))); - check_incoming_tensor_roles(layer, keys(inputs), keys(weight_initializers)); + ASSERT(are_disjoint(unordered_keys(inputs), unordered_keys(weight_initializers))); + check_incoming_tensor_roles(layer, unordered_keys(inputs), unordered_keys(weight_initializers)); std::unordered_map input_shapes = map_values( inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc index d0fd3300f5..3bb508681d 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -12,6 +12,14 @@ #include "utils/containers/vector_of.h" #include "utils/hash/tuple.h" #include "utils/nonnegative_int/num_elements.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/set_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/contains.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/map_values.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/many_to_one/invert_many_to_one.h" namespace FlexFlow { @@ -23,7 +31,7 @@ MappedOperatorTaskGroup::MappedOperatorTaskGroup( transform(vector_of(shard_bindings.right_values()), [&](OperatorAtomicTaskShardBinding const &s) -> std::unordered_set { - return keys(s.tensor_coords); + return unordered_keys(s.tensor_coords); }); std::unordered_set slot_names = @@ -39,8 +47,6 @@ MappedOperatorTaskGroup::MappedOperatorTaskGroup( return ptensor_space_coord_for_slot_name(signature, slot_name); }); - ASSERT(are_all_distinct(coords_for_key)); - std::vector coord_dims_for_key = transform(coords_for_key, [](ParallelTensorSpaceCoordinate const &c) { return ptensor_coord_num_dims(c); @@ -92,15 +98,27 @@ bidict const & return this->shard_bindings; } -bidict +OneToMany get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &task_group, TensorSlotName const &slot_name) { - return transform_values(task_group.get_shard_bindings(), - [&](OperatorAtomicTaskShardBinding const &b) { - return ptensor_space_coord_for_slot_name(b, - slot_name); - }) - .reversed(); + std::set slot_names = get_slot_names_for_task_group(task_group); + ASSERT(contains(slot_names, slot_name)); + + std::unordered_map m = + map_values(task_group.get_shard_bindings().as_unordered_map(), + [&](OperatorAtomicTaskShardBinding const &b) -> ParallelTensorSpaceCoordinate { + return ptensor_space_coord_for_slot_name(b, slot_name); + }); + + return invert_many_to_one(many_to_one_from_unstructured_relation(unordered_set_of(m))); +} + +std::set get_slot_names_for_task_group(MappedOperatorTaskGroup const &g) { + return require_all_same1( + transform(vector_of(right_entries(g.get_shard_bindings())), + [&](OperatorAtomicTaskShardBinding const &shard_bindings) -> std::set { + return keys(shard_bindings.tensor_coords); + })); } nlohmann::json diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index fc1dff504b..1bece17c9a 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -10,6 +10,12 @@ #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" #include "utils/many_to_one/many_to_one_from_map.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "utils/containers/set_of.h" +#include "utils/containers/set_union.h" +#include "utils/containers/keys.h" +#include "utils/containers/require_all_of.h" namespace FlexFlow { @@ -18,6 +24,22 @@ std::unordered_set return get_parallel_layers(pcg_from_mpcg(mpcg)); } +std::set + mpcg_get_invocation_set(MappedParallelComputationGraph const &mpcg) +{ + auto mk_mapped_invocation = [&](ParallelLayerInvocationInfo const &invocation) + -> MappedParallelLayerInvocationInfo + { + MappedOperatorTaskGroup mapping = mpcg_get_mapping_for_layer(mpcg, invocation.layer_info.guid); + + return mapped_parallel_layer_invocation_info_from_pcg_invocation_and_mapping(invocation, mapping); + }; + + ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); + + return transform(pcg_get_invocation_info_set(pcg), mk_mapped_invocation); +} + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &mpcg, parallel_layer_guid_t l) { @@ -112,6 +134,22 @@ MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( return mapped_op_task_groups.at(l); }; + auto slot_names_for_layer = [&](parallel_layer_guid_t l) -> std::set { + return set_union(keys(get_incoming_tensors(pcg, l)), keys(get_outgoing_tensors(pcg, l))); + }; + + auto slot_names_for_layer_mapping = [&](parallel_layer_guid_t l) -> std::set { + return get_slot_names_for_task_group(mapping_for_layer(l)); + }; + + require_all_of( + get_parallel_layers(pcg), + [&](parallel_layer_guid_t l) -> void { + std::set for_layer = slot_names_for_layer(l); + std::set for_layer_mapping = slot_names_for_layer_mapping(l); + ASSERT(for_layer == for_layer_mapping); + }); + auto mpcg_layer_attrs_from_pcg_layer_attrs = [&](Node const &node, ParallelLayerAttrs const &pcg_layer_attrs) -> MappedParallelLayerAttrs { diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc new file mode 100644 index 0000000000..bda6afb60c --- /dev/null +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc @@ -0,0 +1,22 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h" + +namespace FlexFlow { + +MappedParallelLayerInvocationInfo + mapped_parallel_layer_invocation_info_from_pcg_invocation_and_mapping( + ParallelLayerInvocationInfo const &invocation_info, + MappedOperatorTaskGroup const &mapping) +{ + return MappedParallelLayerInvocationInfo{ + /*incoming=*/invocation_info.incoming, + /*layer_info=*/MappedParallelLayerInfo{ + /*guid=*/invocation_info.layer_info.guid, + /*attrs=*/invocation_info.layer_info.attrs, + /*mapping=*/mapping, + }, + /*outgoing=*/invocation_info.outgoing, + }; +} + + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 5098cadafe..c4a429a820 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -37,6 +37,7 @@ #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" #include +#include "utils/containers/map_from_unordered.h" namespace FlexFlow { @@ -98,7 +99,7 @@ ParallelLayerAddedResult add_parallel_layer( std::unordered_map output_flags = maybe_output_flags.value_or( - generate_map(keys(output_shapes), + generate_unordered_map(unordered_keys(output_shapes), [](TensorSlotName const &) { return CreateGrad::YES; })); std::unordered_map output_attrs = @@ -164,6 +165,46 @@ OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, compgraph_op_attrs_from_pcg_op_attrs(op_attrs).value(), input_degrees); } +std::set + pcg_get_invocation_info_set(ParallelComputationGraph const &pcg) +{ + return transform(set_of(get_parallel_layers(pcg)), + [&](parallel_layer_guid_t l) -> ParallelLayerInvocationInfo { + return pcg_get_invocation_info_for_layer(pcg, l); + }); +} + +ParallelLayerInvocationInfo + pcg_get_invocation_info_for_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t l) +{ + ParallelLayerAttrs l_attrs = get_parallel_layer_attrs(pcg, l); + + std::map incoming = + map_from_unordered(get_incoming_tensors(pcg, l)); + + std::map outgoing = + map_from_unordered(get_outgoing_tensors(pcg, l)); + + auto get_parallel_tensor_info = [&](parallel_tensor_guid_t t) -> ParallelTensorInfo { + ParallelTensorAttrs t_attrs = get_parallel_tensor_attrs(pcg, t); + + return ParallelTensorInfo{ + /*guid=*/t, + /*attrs=*/t_attrs, + }; + }; + + return ParallelLayerInvocationInfo{ + /*incoming=*/map_values(incoming, get_parallel_tensor_info), + /*layer_info=*/ParallelLayerInfo{ + /*guid=*/l, + /*attrs=*/l_attrs, + }, + /*outgoing=*/map_values(outgoing, get_parallel_tensor_info), + }; +} + std::unordered_set get_edges(ParallelComputationGraph const &pcg) { return transform(get_all_kwarg_dataflow_edges(pcg.raw_graph), @@ -188,9 +229,10 @@ std::unordered_set get_outgoing_edges(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { std::unordered_set> raw_edges = + unordered_set_of( get_outgoing_kwarg_dataflow_edges_for_node(pcg.raw_graph, l.raw_graph_node) - .right_values(); + .right_values()); return transform(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); @@ -308,7 +350,7 @@ static std::unordered_map ASSERT(incoming_tensors.size() == incoming_slot_roles.size()); std::unordered_set slots_with_desired_role = - keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { + unordered_keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { return role == desired_role; })); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 1d6713dcdb..92334cfde9 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -679,10 +679,10 @@ static void check_incoming_tensor_roles( get_incoming_tensor_roles(layer.op_attrs); std::unordered_map current = binary_merge_disjoint_maps( - generate_map( + generate_unordered_map( input_slots, [](TensorSlotName) { return IncomingTensorRole::INPUT; }), - generate_map(weight_slots, [](TensorSlotName) { + generate_unordered_map(weight_slots, [](TensorSlotName) { return IncomingTensorRole::WEIGHT; })); @@ -698,8 +698,8 @@ std::unordered_map std::unordered_map const &weight_initializers) { - ASSERT(are_disjoint(keys(inputs), keys(weight_initializers))); - check_incoming_tensor_roles(layer, keys(inputs), keys(weight_initializers)); + ASSERT(are_disjoint(unordered_keys(inputs), unordered_keys(weight_initializers))); + check_incoming_tensor_roles(layer, unordered_keys(inputs), unordered_keys(weight_initializers)); std::unordered_map input_shapes = map_values(inputs, [&](parallel_tensor_guid_t const &i) { diff --git a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 7856d89f27..a53c67c336 100644 --- a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -79,18 +79,24 @@ TEST_SUITE(FF_TEST_SUITE) { MappedOperatorTaskGroup partition_mapping = MappedOperatorTaskGroup{ bidict{ - {machine_coord(0_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, - }, - }}, - {machine_coord(1_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, - }, - }}, + { + machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::INPUT, ptensor_coord(0_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, + }, + }, + }, + { + machine_coord(1_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::INPUT, ptensor_coord(0_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, + }, + }, + }, }, }; @@ -116,20 +122,26 @@ TEST_SUITE(FF_TEST_SUITE) { MappedOperatorTaskGroup{ bidict{ - {machine_coord(0_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::LHS_INPUT, ptensor_coord(0_n)}, - {TensorSlotName::RHS_INPUT, ptensor_coord(0_n)}, - }, - }}, - {machine_coord(1_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::LHS_INPUT, ptensor_coord(1_n)}, - {TensorSlotName::RHS_INPUT, ptensor_coord(1_n)}, - }, - }}, + { + machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::LHS_INPUT, ptensor_coord(0_n)}, + {TensorSlotName::RHS_INPUT, ptensor_coord(0_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, + }, + }, + }, + { + machine_coord(1_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::LHS_INPUT, ptensor_coord(1_n)}, + {TensorSlotName::RHS_INPUT, ptensor_coord(1_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, + }, + }, + }, }, }}, }; diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index fd314ebaea..44488e70ea 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -5,7 +5,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" #include "utils/containers/count.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_only.h" #include "utils/containers/items.h" #include "utils/containers/require_only_key.h" @@ -248,7 +248,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*paddingW=*/paddingW); std::unordered_map layers = - generate_map(get_parallel_layers(b.pcg), + generate_unordered_map(get_parallel_layers(b.pcg), [&](parallel_layer_guid_t const &l) { return get_parallel_layer_attrs(b.pcg, l); }); diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 4ef2919b10..79eef476df 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -42,7 +42,7 @@ TensorInstanceBacking perform_instance_allocation( RealmContext &ctx) { ASSERT(no_tensors_are_allocated(g)); ASSERT(tensors_are_ready_for_allocation(g)); - for (DynamicValueAttrs const &v : keys(preallocated)) { + for (DynamicValueAttrs const &v : unordered_keys(preallocated)) { ASSERT(v.accessor == std::nullopt); } diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index aa67110127..4b068d70be 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -233,10 +233,10 @@ static Realm::Event spawn_dynamic_node_invocation( // chain reductions sequentially to avoid write races on dst Realm::Event result = precondition; - for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) { + for (auto const &[p, m] : unstructured_relation_from_one_to_many(assert_unwrap(output_grad.mapping))) { DynamicValueAttrs replica_key = output_grad; replica_key.mapping = - bidict{{p, m}}; + OneToMany{{p, {m}}}; replica_key.shard_coord = p; Realm::RegionInstance src_inst = diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc index 79a5176c4f..1d53824c29 100644 --- a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc @@ -8,25 +8,29 @@ namespace FlexFlow { SerializableTensorInstanceBacking tensor_instance_backing_to_serializable( TensorInstanceBacking const &backing) { - return SerializableTensorInstanceBacking{/*backing=*/map_keys_and_values( + return SerializableTensorInstanceBacking{ + /*backing=*/map_keys_and_values( backing.backing, dynamic_value_attrs_to_serializable, [](std::pair const &p) { return std::pair{realm_instance_to_serializable(p.first), realm_event_to_serializable(p.second)}; - })}; + }), + }; } TensorInstanceBacking tensor_instance_backing_from_serializable( SerializableTensorInstanceBacking const &backing) { - return TensorInstanceBacking{/*backing=*/map_keys_and_values( + return TensorInstanceBacking{ + /*backing=*/map_keys_and_values( backing.backing, dynamic_value_attrs_from_serializable, [](std::pair const &p) { return std::pair{realm_instance_from_serializable(p.first), realm_event_from_serializable(p.second)}; - })}; + }), + }; } } // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index 46d29e2bef..6efbb17eb3 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -197,12 +197,14 @@ MappedParallelComputationGraph { cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, {TensorSlotName::OUTPUT, tensor_coord0}, }}, }, { cpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, {TensorSlotName::OUTPUT, tensor_coord1}, }}, }, diff --git a/lib/runtime/src/parallel_tensor_uses.cc b/lib/runtime/src/parallel_tensor_uses.cc index 444d3a061a..970a9109dc 100644 --- a/lib/runtime/src/parallel_tensor_uses.cc +++ b/lib/runtime/src/parallel_tensor_uses.cc @@ -31,7 +31,7 @@ Op const *ParallelTensorUses::get_owner(ParallelTensor const &tensor) const { } void ParallelTensorUses::remove(Op const &op) { - for (auto const &k : keys(this->uses)) { + for (auto const &k : unordered_keys(this->uses)) { inplace_filter(this->uses.at(k), [&](ParallelTensorUseDescription const &d) { return d.op->op_guid == op.op_guid; diff --git a/lib/runtime/src/tensor_uses.cc b/lib/runtime/src/tensor_uses.cc index ce4672342d..db2cc942d4 100644 --- a/lib/runtime/src/tensor_uses.cc +++ b/lib/runtime/src/tensor_uses.cc @@ -20,7 +20,7 @@ std::vector TensorUses::at(size_t tensor_guid) const { } void TensorUses::remove(Layer const &layer) { - for (auto const &k : keys(this->uses)) { + for (auto const &k : unordered_keys(this->uses)) { inplace_filter(this->uses.at(k), [&](TensorUseDescription const &d) { return d.layer->layer_guid == layer.layer_guid; }); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index f2686f7cf7..f3ceda7a06 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -10,7 +10,7 @@ #include "substitutions/sub_parallel_computation_graph_data.h" #include "substitutions/sub_parallel_computation_graph_edge.h" #include "utils/containers/binary_merge_disjoint_maps.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" #include "utils/containers/values.h" @@ -50,7 +50,7 @@ SubParallelComputationGraph apply_substitution_from_output_result( require_sub_parallel_computation_graph_data_is_valid(pre_data); std::unordered_set pre_nodes = - keys(pre_data.node_data); + unordered_keys(pre_data.node_data); std::unordered_set matched_nodes = unordered_set_of(values(match.node_assignment)); std::unordered_set post_nodes_from_original_graph = diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc index 8e1c06b9b5..9ae007ef16 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -56,12 +56,12 @@ LabelledOpenKwargDataflowGraphView incoming_tensor_roles = get_incoming_tensor_roles(n_attrs.op_attrs); - ASSERT(is_subseteq_of(keys(incoming_shapes), keys(incoming_tensor_roles))); + ASSERT(is_subseteq_of(unordered_keys(incoming_shapes), unordered_keys(incoming_tensor_roles))); auto incoming_shapes_with_role = [&](IncomingTensorRole role) -> std::unordered_map { std::unordered_set slots_with_desired_role = - keys(filter_values(incoming_tensor_roles, + unordered_keys(filter_values(incoming_tensor_roles, [&](IncomingTensorRole r) { return r == role; })); return restrict_keys(incoming_shapes, slots_with_desired_role); diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index 85a0493e33..8a71fe2ad5 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -76,7 +76,7 @@ void assert_pcg_pattern_match_is_valid_for_pattern_and_subpcg( std::unordered_set pattern_inputs = get_inputs(pattern); std::unordered_set match_pattern_inputs = - keys(match.input_assignment); + unordered_keys(match.input_assignment); ASSERT(pattern_inputs == match_pattern_inputs); } diff --git a/lib/substitutions/src/substitutions/substitution_builder.cc b/lib/substitutions/src/substitutions/substitution_builder.cc index f2860326ab..ffda2291b5 100644 --- a/lib/substitutions/src/substitutions/substitution_builder.cc +++ b/lib/substitutions/src/substitutions/substitution_builder.cc @@ -94,7 +94,7 @@ std::unordered_map node_expr, map_values(inputs, raw_open_kwarg_dataflow_value_from_output_graph_expr_value), - generate_map(output_slots, + generate_unordered_map(output_slots, [](TensorSlotName) { return std::monostate{}; })); return map_values( diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index e9087b5718..a982277d22 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -16,9 +16,6 @@ #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" #include "utils/many_to_one/invert_many_to_one.h" #include "utils/many_to_one/many_to_one_from_map.h" -#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" -#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" -#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" #include "utils/overload.h" namespace FlexFlow { @@ -46,7 +43,7 @@ static std::optional return OpenKwargDataflowValue{o}; }); - if (keys(pattern_outputs) != keys(graph_outputs)) { + if (unordered_keys(pattern_outputs) != unordered_keys(graph_outputs)) { return std::nullopt; } @@ -64,7 +61,7 @@ static std::optional graph_node_inputs = get_incoming_open_kwarg_dataflow_values_for_node(graph, graph_node); - if (keys(graph_node_inputs) != keys(pattern_node_inputs)) { + if (unordered_keys(graph_node_inputs) != unordered_keys(pattern_node_inputs)) { return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 703d651070..25d505b1fa 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -190,7 +190,7 @@ bool unlabelled_pattern_does_match( ASSERT(left_entries(match.node_assignment) == get_pattern_nodes(pattern)); ASSERT( is_subseteq_of(right_entries(match.node_assignment), get_nodes(graph))); - ASSERT(keys(match.input_assignment) == get_pattern_inputs(pattern)); + ASSERT(unordered_keys(match.input_assignment) == get_pattern_inputs(pattern)); ASSERT(is_subseteq_of(matched_by_pattern_inputs, get_all_open_kwarg_dataflow_values(graph))); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h b/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h index a1726c2ae1..7a383ee8eb 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h @@ -13,6 +13,11 @@ bool value_is_mapped(DynamicValueAttrs const &); bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &); bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &); +std::unordered_set copies_for_invocation_inputs( + DynamicNodeInvocation const &i, + std::unordered_map const + &unmapped_value_to_mapped_source_value); + std::unordered_set perform_copy_insertion_for_invocation( DynamicNodeInvocation const &i, std::unordered_map const diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h new file mode 100644 index 0000000000..a87f2241c4 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_NODE_INVOCATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_NODE_INVOCATION_H + +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +bool invocation_fully_satisfies(DynamicNodeInvocation const &, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition); + +void require_invocation_fully_satisfies(DynamicNodeInvocation const &, + std::function const &require_node_condition, + std::function const &require_value_condition, + std::function const &require_slot_condition); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml new file mode 100644 index 0000000000..a59aba92d7 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "DynamicNodeInvocationShardingInfo" +type = "struct" +#include "task-spec/dynamic_graph/shard_expansion.h" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/machine_space_coordinate.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h", + "task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.h", +] + +src_includes = [ + "utils/hash/map.h", + "utils/fmt/map.h", +] + +[[fields]] +name = "device_coord" +type = "::FlexFlow::MachineSpaceCoordinate" + +[[fields]] +name = "value_sharding" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrsShardingInfo>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h index 4ca62db5b1..1aba00a675 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h @@ -23,6 +23,12 @@ bool no_part_of_dynamic_graph_satisfies( std::function const &, std::function const &); +void require_full_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &, + std::function const &, + std::function const &, + std::function const &); + std::unordered_multiset get_dynamic_nodes(DynamicOpenDataflowGraph const &); std::unordered_multiset diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index 490a51f88d..add72764f1 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -13,7 +13,7 @@ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", "op-attrs/parallel_tensor_space_coordinate.dtg.h", "pcg/machine_space_coordinate.dtg.h", - "utils/bidict/bidict.h", + "utils/one_to_many/one_to_many.h", "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", ] @@ -25,18 +25,38 @@ src_includes = [ [[fields]] name = "tensor_guid" type = "::FlexFlow::dynamic_tensor_guid_t" +docstring = ''' +\brief The \ref tensor_guid_t or \ref parallel_tensor_guid_t of the (usually parallel) tensor this value originates from. Also allows representing tensors for computing the loss that lie outside of the scope of the \ref ComputationGraph or \ref ParallelComputationGraph, e.g., the label tensor. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MapepdParallelComputationGraph, this field is filled in by \ref make_dynamic_open_dataflow_graph_from_mapped_pcg.h. +''' [[fields]] name = "parallel_tensor_shape" type = "std::optional<::FlexFlow::ParallelTensorShape>" +docstring = ''' +\brief The \ref ParallelTensorShape of the parallel tensor this value originates from. + +For a \ref DynamicOpenDataflowGraph originating form a \ref MappedParallelComputationGraph, this field is filled in by \ref make_dynamic_open_dataflow_graph_from_mapped_pcg.h. +''' [[fields]] name = "shard_coord" type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" +docstring = ''' +\brief The shard (i.e., \ref ParallelTensorSpaceCoordinate) of the (usually parallel) tensor represented by this value. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MappedParallelComputationGraph, this field is filled in by \ref shard_expansion.h. +''' [[fields]] name = "mapping" -type = "std::optional<::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" +type = "std::optional<::FlexFlow::OneToMany<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" +docstring = ''' +\brief The location (i.e., \ref MachineSpaceCoordinate) of each shard (i.e., \ref ParallelTensorSpaceCoordinate) of this (usually parallel) tensor. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MappedParallelComputationGraph, this field is filled in by \ref shard_expansion.h. +''' [[fields]] name = "accessor" @@ -45,3 +65,8 @@ type = "std::optional<::FlexFlow::DynamicTensorAccessor>" [[fields]] name = "role" type = "std::optional<::FlexFlow::DynamicTensorRole>" +docstring = ''' +\brief Identifies the role this tensor plays in training, e.g., a forward tensor, a gradient tensor, an optimizer buffer, a loss tensor, etc. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MappedParallelComputationGraph, this field is filled in by \ref pass_expansion.h. +''' diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h index 9cccc565cc..aa9fbd2874 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h @@ -8,6 +8,10 @@ namespace FlexFlow { DynamicValueAttrs decide_dynamic_value_attrs_role(DynamicValueAttrs const &, DynamicTensorRole); +DynamicValueAttrs decide_dynamic_value_attrs_mapping( + DynamicValueAttrs const &, + OneToMany const &); + } // namespace FlexFlow #endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml new file mode 100644 index 0000000000..5a9234d815 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "DynamicValueAttrsShardingInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "utils/one_to_many/one_to_many.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "pcg/machine_space_coordinate.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "shard_coord" +type = "::FlexFlow::ParallelTensorSpaceCoordinate" + +[[fields]] +name = "mapping" +type = "::FlexFlow::OneToMany<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/index.dox b/lib/task-spec/include/task-spec/dynamic_graph/index.dox index c48e67f4b3..97b72f8553 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/index.dox +++ b/lib/task-spec/include/task-spec/dynamic_graph/index.dox @@ -7,6 +7,7 @@ namespace FlexFlow { \section task-spec-lowering-passes Lowering Passes +- \ref make_dynamic_open_dataflow_graph_from_mapped_pcg.h: Embeds a \ref MappedParallelComputationGraph as a \ref DynamicOpenDataflowGraph. The first of the lowering passes. - \ref pass_expansion.h - \ref shard_expansion.h - \ref update_insertion.h diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h index 6a269ec3c9..f1d693c975 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h @@ -3,9 +3,15 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.h" namespace FlexFlow { +DynamicNodeInvocation make_dynamic_node_invocation_from_mapped( + MappedParallelLayerInvocationInfo const &); + +DynamicNodeInvocation build_replicate_invocation(MappedParallelLayerInvocationInfo const &); + DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h b/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h index 6dce8ad514..ad07b2941f 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h @@ -10,8 +10,13 @@ bool node_is_pass_expanded(DynamicNodeAttrs const &); bool value_is_pass_expanded(DynamicValueAttrs const &); bool slot_is_pass_expanded(DynamicTensorSlot const &); +bool node_is_ready_for_pass_expansion(DynamicNodeAttrs const &); +bool value_is_ready_for_pass_expansion(DynamicValueAttrs const &); +bool slot_is_ready_for_pass_expansion(DynamicTensorSlot const &); + bool no_part_of_graph_is_pass_expanded(DynamicOpenDataflowGraph const &); bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &); +bool graph_is_ready_for_pass_expansion(DynamicOpenDataflowGraph const &); DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation(DynamicNodeInvocation const &); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml index 454f1b7e8c..d3cab6ecdb 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -14,7 +14,7 @@ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", "op-attrs/parallel_tensor_space_coordinate.dtg.h", "pcg/machine_space_coordinate.dtg.h", - "utils/bidict/bidict.h", + "utils/one_to_many/one_to_many.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", ] @@ -37,7 +37,7 @@ type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" [[fields]] name = "mapping" -type = "std::optional<::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" +type = "std::optional<::FlexFlow::OneToMany<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" [[fields]] name = "role" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h b/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h index 4e0db1cd7e..50c713ee3b 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h @@ -4,19 +4,42 @@ #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.h" namespace FlexFlow { -bool node_is_shard_expanded(DynamicNodeAttrs const &); -bool value_is_shard_expanded(DynamicValueAttrs const &); +[[nodiscard]] bool node_is_shard_expanded(DynamicNodeAttrs const &); +[[nodiscard]] bool value_is_shard_expanded(DynamicValueAttrs const &); +[[nodiscard]] bool invocation_is_fully_shard_expanded(DynamicNodeInvocation const &); -bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &); -bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &); +[[nodiscard]] bool node_is_ready_for_shard_expansion(DynamicNodeAttrs const &); +[[nodiscard]] bool value_is_ready_for_shard_expansion(DynamicValueAttrs const &); +[[nodiscard]] bool invocation_is_ready_for_shard_expansion(DynamicNodeInvocation const &); -std::unordered_set +[[nodiscard]] bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &); +[[nodiscard]] bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &); +[[nodiscard]] bool graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const &); + +[[nodiscard]] DynamicNodeAttrs apply_dynamic_node_attrs_sharding_info( + DynamicNodeAttrs const &, + MachineSpaceCoordinate const &); + +[[nodiscard]] DynamicValueAttrs apply_dynamic_value_attrs_sharding_info( + DynamicValueAttrs const &, + DynamicValueAttrsShardingInfo const &); + +[[nodiscard]] DynamicNodeInvocation apply_dynamic_node_invocation_sharding_info( + DynamicNodeInvocation const &, + DynamicNodeInvocationShardingInfo const &); + +[[nodiscard]] std::unordered_set + generate_shard_expansion_for_invocation(DynamicNodeInvocation const &); + +[[nodiscard]] std::unordered_set perform_shard_expansion_for_invocation(DynamicNodeInvocation const &); -DynamicOpenDataflowGraph +[[nodiscard]] DynamicOpenDataflowGraph perform_shard_expansion(DynamicOpenDataflowGraph const &); } // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h b/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h index 23fb7050a0..9818152b34 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h @@ -6,6 +6,15 @@ namespace FlexFlow { +bool node_has_already_had_update_insertion_performed(DynamicNodeAttrs const &); +bool value_has_already_had_update_insertion_performed(DynamicValueAttrs const &); + +bool node_is_ready_for_update_insertion(DynamicNodeAttrs const &); +bool value_is_ready_for_update_insertion(DynamicValueAttrs const &); + +bool no_part_of_graph_has_had_update_insertion_performed(DynamicOpenDataflowGraph const &); +bool graph_is_ready_for_update_insertion(DynamicOpenDataflowGraph const &); + std::unordered_set perform_update_insertion_for_invocation(DynamicNodeInvocation const &, OptimizerAttrs const &); diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index ef41042a51..08ab3b11aa 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -10,7 +10,7 @@ #include "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "utils/bidict/algorithms/bidict_from_pairs.h" -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" #include "utils/containers/intersection.h" @@ -25,25 +25,13 @@ bool node_is_copy(DynamicNodeAttrs const &n) { return n.op_attrs.has_value() && n.op_attrs.value().is_copy(); } -static bool is_replicate_invocation(DynamicNodeInvocation const &i) { - return i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().has() && - i.node_attrs.op_attrs.value() - .get() - .has(); -} - bool value_is_mapped(DynamicValueAttrs const &n) { return n.mapping.has_value(); } bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &g) { auto slot_is_mapped = [](DynamicTensorSlot const &) -> bool { return false; }; - // check all non-replicate invocations for (DynamicNodeInvocation const &i : g.invocations) { - if (is_replicate_invocation(i)) { - continue; // replicate tensors have mapping set by design - } if (node_is_copy(i.node_attrs)) { return false; } @@ -69,6 +57,26 @@ bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) { g, node_is_any, value_is_mapped, slot_is_mapped); } +void require_node_is_ready_for_copy_insertion(DynamicNodeAttrs const &node_attrs) { + ASSERT(node_attrs.mapping.has_value()); +} + +void require_graph_is_ready_for_copy_insertion(DynamicOpenDataflowGraph const &g) { + auto require_slot_is_ready_for_copy_insertion = [](DynamicTensorSlot const &slot) -> void { + return; + }; + + auto require_value_is_ready_for_copy_insertion = [](DynamicValueAttrs const &value_attrs) -> void { + return; + }; + + require_full_dynamic_graph_satisfies( + g, + require_node_is_ready_for_copy_insertion, + require_value_is_ready_for_copy_insertion, + require_slot_is_ready_for_copy_insertion); +} + static DynamicValueAttrs map_dynamic_value_attrs_for_task_group( DynamicTensorSlot const &slot, DynamicValueAttrs const &value, @@ -83,10 +91,10 @@ static std::pair DynamicValueAttrs const &output) { std::unordered_set< std::pair> - input_mapping = unordered_set_of(assert_unwrap(input.mapping)); + input_mapping = unstructured_relation_from_one_to_many(assert_unwrap(input.mapping)); std::unordered_set< std::pair> - output_mapping = unordered_set_of(assert_unwrap(output.mapping)); + output_mapping = unstructured_relation_from_one_to_many(assert_unwrap(output.mapping)); // Exclude the point shared between the input and output mappings, because // those will not result in actual copies once shard expansion is performed @@ -96,25 +104,19 @@ static std::pair DynamicValueAttrs filtered_input = input; filtered_input.mapping = - bidict_from_pairs(set_difference(input_mapping, remove)); + one_to_many_from_unstructured_relation(set_difference(input_mapping, remove)); DynamicValueAttrs filtered_output = output; filtered_output.mapping = - bidict_from_pairs(set_difference(output_mapping, remove)); + one_to_many_from_unstructured_relation(set_difference(output_mapping, remove)); return std::pair{filtered_input, filtered_output}; } -std::unordered_set perform_copy_insertion_for_invocation( - DynamicNodeInvocation const &i, - std::unordered_map const - &unmapped_value_to_mapped_source_value) { - - // replicate nodes have no MappedOperatorTaskGroup — - // pass through unchanged, no copies needed - if (is_replicate_invocation(i)) { - return {i}; - } +std::unordered_set copies_for_invocation_inputs( + DynamicNodeInvocation const &i, + std::unordered_map const &unmapped_value_to_src_mapped_value) +{ MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); auto map_tensor = [&](DynamicTensorSlot const &slot, @@ -124,31 +126,26 @@ std::unordered_set perform_copy_insertion_for_invocation( std::unordered_map mapped_inputs = map_values2(i.inputs, map_tensor); - std::unordered_map mapped_outputs = - map_values2(i.outputs, map_tensor); - std::unordered_set result{DynamicNodeInvocation{ - /*inputs=*/mapped_inputs, - /*node_attrs=*/i.node_attrs, - /*outputs=*/mapped_outputs, - }}; + std::unordered_set result; for (auto const &[slot, input] : i.inputs) { - if (!contains_key(unmapped_value_to_mapped_source_value, input)) { + if (!contains_key(unmapped_value_to_src_mapped_value, input)) { continue; } - DynamicValueAttrs source_value = - unmapped_value_to_mapped_source_value.at(input); - DynamicValueAttrs use_value = mapped_inputs.at(slot); - if (source_value != use_value) { - auto const &[filtered_source, filtered_use] = - filter_mapping_to_avoid_degenerate_copies(source_value, use_value); + DynamicValueAttrs src_mapped_value = unmapped_value_to_src_mapped_value.at(input); + DynamicValueAttrs use_mapped_value = mapped_inputs.at(slot); + + if (src_mapped_value != use_mapped_value) { + auto const &[filtered_source, filtered_use] = filter_mapping_to_avoid_degenerate_copies(src_mapped_value, use_mapped_value); DynamicNodeInvocation copy{ /*inputs=*/{ { - DynamicTensorSlot{TensorSlotName::INPUT, - slot.slot_tensor_role}, + DynamicTensorSlot{ + TensorSlotName::INPUT, + slot.slot_tensor_role, + }, filtered_source, }, }, @@ -179,22 +176,48 @@ std::unordered_set perform_copy_insertion_for_invocation( return result; } +std::unordered_set perform_copy_insertion_for_invocation( + DynamicNodeInvocation const &i, + std::unordered_map const + &unmapped_value_to_mapped_source_value) { + + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); + + auto map_tensor = [&](DynamicTensorSlot const &slot, + DynamicValueAttrs const &value) { + return map_dynamic_value_attrs_for_task_group(slot, value, mapping); + }; + + DynamicNodeInvocation mapped_i = [&] { + std::unordered_map mapped_inputs = + map_values2(i.inputs, map_tensor); + std::unordered_map mapped_outputs = + map_values2(i.outputs, map_tensor); + + DynamicNodeInvocation r = i; + r.inputs = mapped_inputs; + r.outputs = mapped_outputs; + return r; + }(); + + std::unordered_set result = set_union( + copies_for_invocation_inputs(i, unmapped_value_to_mapped_source_value), + std::unordered_set{ + mapped_i, + }); + + return result; +} + DynamicOpenDataflowGraph perform_copy_insertion(DynamicOpenDataflowGraph const &g) { ASSERT(no_part_of_graph_is_copy_inserted(g)); + require_graph_is_ready_for_copy_insertion(g); std::unordered_map unmapped_value_to_mapped_source_value; for (DynamicNodeInvocation const &i : g.invocations) { - // replicate nodes have no MappedOperatorTaskGroup — - // output mapping already fully set, maps to itself - if (is_replicate_invocation(i)) { - for (auto const &[slot, value] : i.outputs) { - unmapped_value_to_mapped_source_value.insert(std::pair{value, value}); - } - continue; - } for (auto const &[slot, value] : i.outputs) { unmapped_value_to_mapped_source_value.insert( std::pair{value, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc new file mode 100644 index 0000000000..ea50449347 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc @@ -0,0 +1,35 @@ +#include "task-spec/dynamic_graph/dynamic_node_invocation.h" +#include "utils/containers/values.h" +#include "utils/containers/keys.h" +#include "utils/containers/all_of.h" + +namespace FlexFlow { + +bool invocation_fully_satisfies(DynamicNodeInvocation const &i, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition) +{ + return node_condition(i.node_attrs) + && all_of(values(i.inputs), value_condition) + && all_of(keys(i.inputs), slot_condition) + && all_of(values(i.outputs), value_condition) + && all_of(keys(i.outputs), slot_condition); +} + +void require_invocation_fully_satisfies(DynamicNodeInvocation const &i, + std::function const &require_node_condition, + std::function const &require_value_condition, + std::function const &require_slot_condition) { + require_node_condition(i.node_attrs); + for (DynamicTensorSlot const &k : keys(i.inputs)) { + require_slot_condition(k); + require_value_condition(i.inputs.at(k)); + } + for (DynamicTensorSlot const &k : keys(i.outputs)) { + require_slot_condition(k); + require_value_condition(i.outputs.at(k)); + } +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index d2a5b653e5..a100c3adfb 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -19,6 +19,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" #include "utils/many_to_one/many_to_one.h" +#include "utils/containers/require_all_of.h" namespace FlexFlow { @@ -56,6 +57,18 @@ bool no_part_of_dynamic_graph_satisfies( [&](DynamicTensorSlot const &s) -> bool { return !slot_condition(s); }); } +void require_full_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &g, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition) +{ + require_all_of(get_dynamic_nodes(g), node_condition); + require_all_of(get_dynamic_values(g), value_condition); + require_all_of(get_dynamic_tensor_slots(g), slot_condition); +} + + std::unordered_multiset get_dynamic_nodes(DynamicOpenDataflowGraph const &g) { return transform(unordered_multiset_of(g.invocations), diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc index 282279edbe..9a70c5cdd0 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc @@ -13,4 +13,17 @@ DynamicValueAttrs return result; } +DynamicValueAttrs decide_dynamic_value_attrs_mapping( + DynamicValueAttrs const &attrs, + OneToMany const &mapping) +{ + ASSERT(!attrs.mapping.has_value()); + + DynamicValueAttrs result = attrs; + result.mapping = mapping; + + return result; +} + + } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 6bfc477e3a..7fe3927fd1 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -7,7 +7,7 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include #include #include diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 7a149787b9..391ebaff3b 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -10,7 +10,6 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" -#include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" #include "utils/containers/map_keys_and_values.h" #include "utils/containers/require_only_key.h" @@ -18,194 +17,64 @@ #include #include #include +#include "utils/containers/unordered_map_from_map.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" namespace FlexFlow { -static bidict - get_input_mapping_for_replicate( - MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &replicate_layer) { - - ASSERT(mpcg_get_pcg_op_attrs(mpcg, replicate_layer).is_parallel_replicate()); - - auto [input_slot_name, input_edge] = - get_only(mpcg_get_incoming_edges(mpcg, replicate_layer)); - - parallel_layer_guid_t producer_layer = get_src_layer(input_edge); - TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); - - return get_tensor_bindings_for_slot_name( - /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), - /*slot_name=*/producer_slot); -} - -static bidict - build_replicated_output_mapping( - MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &output_tensor_guid) { - - std::unordered_set consumers = - mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); - ASSERT(!consumers.empty()); - - // union all consumer bindings — each consumer shard maps to a distinct - // (discard_copy, machine) pair since replicas are always on different machines - bidict result = - merge_disjoint_bidicts(transform( - consumers, - [&](parallel_tensor_use_t const &use) - -> bidict { - parallel_layer_guid_t consumer_layer = - parallel_tensor_use_get_layer(use); - TensorSlotName slot_name = parallel_tensor_use_get_slot(use); - - MappedOperatorTaskGroup consumer_mapping = - mpcg_get_mapping_for_layer(mpcg, consumer_layer); - bidict - binding = get_tensor_bindings_for_slot_name(consumer_mapping, - slot_name); - - return binding; - })); - - return result; -} - -static DynamicNodeInvocation - build_replicate_invocation(parallel_layer_guid_t const &layer, - ReplicateAttrs const &attrs, - MappedParallelComputationGraph const &mpcg) { - - ManyToOne incoming = - mpcg_get_incoming_tensors(mpcg, layer); - TensorSlotName input_slot_name = TensorSlotName::INPUT; - parallel_tensor_guid_t input_tensor_guid = - require_only_key(incoming.l_to_r(), input_slot_name); - ParallelTensorAttrs input_attrs = - mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); - - bidict outgoing = - mpcg_get_outgoing_tensors(mpcg, layer); - TensorSlotName output_slot_name = TensorSlotName::OUTPUT; - parallel_tensor_guid_t output_tensor_guid = - require_only_key(outgoing.l_to_r(), output_slot_name); - ParallelTensorAttrs output_attrs = - mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); - - bidict input_mapping = - get_input_mapping_for_replicate(mpcg, layer); - - DynamicValueAttrs input_value{ - /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, - /*parallel_tensor_shape=*/input_attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/input_mapping, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }; - - DynamicValueAttrs output_value{ - /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, - /*parallel_tensor_shape=*/output_attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/build_replicated_output_mapping(mpcg, output_tensor_guid), - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }; - - DynamicNodeAttrs node_attrs{ +DynamicNodeInvocation make_dynamic_node_invocation_from_mapped( + MappedParallelLayerInvocationInfo const &invocation_info) +{ + DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{PCGOperatorAttrs{attrs}}, - /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*mapping=*/invocation_info.layer_info.mapping, + /*op_attrs=*/TrainingOperationAttrs{invocation_info.layer_info.attrs.op_attrs}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{invocation_info.layer_info.guid}, /*per_device_op_state=*/std::nullopt, }; - DynamicNodeInvocation invocation_node{ - /*inputs=*/{ - { - DynamicTensorSlot{input_slot_name, std::nullopt}, - input_value, - }, + auto lift_kv_pair = + [&](TensorSlotName slot_name, + ParallelTensorInfo const &tensor) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, }, - /*node_attrs=*/node_attrs, - /*outputs=*/ - { - { - DynamicTensorSlot{output_slot_name, std::nullopt}, - output_value, - }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor.guid}, + /*parallel_tensor_shape=*/tensor.attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, }, + }; }; - return invocation_node; -} - -DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( - MappedParallelComputationGraph const &mpcg) { - - ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); - - auto mk_invocation = - [&](parallel_layer_guid_t layer, - ParallelLayerAttrs const &attrs) -> DynamicNodeInvocation { - if (attrs.op_attrs.is_parallel_replicate()) { - // build replicate invocation - DynamicNodeInvocation repl_inv = build_replicate_invocation( - layer, attrs.op_attrs.require_parallel_replicate(), mpcg); - return repl_inv; - } else { - DynamicNodeAttrs result_attrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, - /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, - /*per_device_op_state=*/std::nullopt, - }; + std::map result_inputs = + transform(invocation_info.incoming, lift_kv_pair); - auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { - return DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }; - }; + std::map result_outputs = + transform(invocation_info.outgoing, lift_kv_pair); - auto mk_value_attrs = - [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs { - ParallelTensorAttrs attrs = get_parallel_tensor_attrs(pcg, tensor); - - return DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }; - }; - - std::unordered_map result_inputs = - map_keys_and_values( - get_incoming_tensors(pcg, layer), mk_slot, mk_value_attrs); - - std::unordered_map result_outputs = - map_keys_and_values( - get_outgoing_tensors(pcg, layer), mk_slot, mk_value_attrs); + DynamicNodeInvocation invocation = DynamicNodeInvocation{ + /*inputs=*/unordered_map_from_map(result_inputs), + /*node_attrs=*/result_attrs, + /*outputs=*/unordered_map_from_map(result_outputs), + }; - DynamicNodeInvocation invocation = DynamicNodeInvocation{ - /*inputs=*/result_inputs, - /*node_attrs=*/result_attrs, - /*outputs=*/result_outputs, - }; + return invocation; +} - return invocation; - }; - }; +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( + MappedParallelComputationGraph const &mpcg) { - return dynamic_open_dataflow_graph_from_invocation_set(transform_pairs( - unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), mk_invocation)); + return dynamic_open_dataflow_graph_from_invocation_set( + transform(unordered_set_of(mpcg_get_invocation_set(mpcg)), make_dynamic_node_invocation_from_mapped)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 64fe2df0be..a348ba77da 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -21,6 +21,18 @@ bool value_is_pass_expanded(DynamicValueAttrs const &v) { return v.role.has_value(); } +bool node_is_ready_for_pass_expansion(DynamicNodeAttrs const &) { + return true; +} + +bool value_is_ready_for_pass_expansion(DynamicValueAttrs const &) { + return true; +} + +bool slot_is_ready_for_pass_expansion(DynamicTensorSlot const &) { + return true; +} + bool no_part_of_graph_is_pass_expanded(DynamicOpenDataflowGraph const &g) { return no_part_of_dynamic_graph_satisfies( g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); @@ -31,6 +43,11 @@ bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); } +bool graph_is_ready_for_pass_expansion(DynamicOpenDataflowGraph const &g) { + return full_dynamic_graph_satisfies( + g, node_is_ready_for_pass_expansion, value_is_ready_for_pass_expansion, slot_is_ready_for_pass_expansion); +} + DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, FwbTensorType tensor_type) { ASSERT(!slot_is_pass_expanded(s)); @@ -139,6 +156,7 @@ DynamicOpenDataflowGraph perform_pass_expansion(DynamicOpenDataflowGraph const &g) { ASSERT(no_part_of_graph_is_pass_expanded(g)); + ASSERT(graph_is_ready_for_pass_expansion(g)); DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index d3365ae44c..badd376a8b 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -1,12 +1,16 @@ #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" -#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/bidict/algorithms/bidict_filter_keys.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values2.h" #include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/optional.h" +#include "utils/containers/binary_merge_disjoint_maps.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.h" +#include "utils/containers/map_from_unordered.h" +#include "utils/one_to_many/one_to_many_filter_keys.h" namespace FlexFlow { @@ -14,8 +18,74 @@ bool node_is_shard_expanded(DynamicNodeAttrs const &n) { return n.device_coord.has_value(); } +bool node_is_ready_for_shard_expansion(DynamicNodeAttrs const &n) { + if (!n.op_attrs.has_value()) { + return false; + } + + if (n.op_attrs.value().is_pcg_op()) { + if (!n.mapping.has_value()) { + return false; + } + } + + return true; +} + +void require_node_is_ready_for_shard_expansion(DynamicNodeAttrs const &n) { + ASSERT(n.op_attrs.has_value()); + if (n.op_attrs.value().is_pcg_op()) { + ASSERT(n.mapping.has_value()); + } +} + + +bool invocation_is_fully_shard_expanded(DynamicNodeInvocation const &i) { + auto slot_is_shard_expanded = [](DynamicTensorSlot const &) { + return true; + }; + + return invocation_fully_satisfies( + i, + node_is_shard_expanded, + value_is_shard_expanded, + slot_is_shard_expanded); +} + bool value_is_shard_expanded(DynamicValueAttrs const &n) { - return n.shard_coord.has_value(); + return n.shard_coord.has_value() && n.mapping.has_value(); +} + +bool value_is_ready_for_shard_expansion(DynamicValueAttrs const &n) { + return true; +} + +void require_value_is_ready_for_shard_expansion(DynamicValueAttrs const &n) { + return; +} + +bool invocation_is_ready_for_shard_expansion(DynamicNodeInvocation const &i) { + auto slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) { + return true; + }; + + return invocation_fully_satisfies( + i, + node_is_ready_for_shard_expansion, + value_is_ready_for_shard_expansion, + slot_is_ready_for_shard_expansion); +} + +void require_invocation_is_ready_for_shard_expansion(DynamicNodeInvocation const &i) { + auto require_slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) -> void { + return; + }; + + require_invocation_fully_satisfies( + i, + require_node_is_ready_for_shard_expansion, + require_value_is_ready_for_shard_expansion, + require_slot_is_ready_for_shard_expansion); } bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &g) { @@ -39,20 +109,53 @@ bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { value_is_shard_expanded, slot_is_shard_expanded); } -static bidict + +static OneToMany restrict_tensor_mapping_keys_to_coord( - bidict const + OneToMany const &mapping, ParallelTensorSpaceCoordinate const ¶llel_tensor_coord) { - return filter_keys(mapping, [&](ParallelTensorSpaceCoordinate const &p) { + return one_to_many_filter_keys(mapping, [&](ParallelTensorSpaceCoordinate const &p) { return p == parallel_tensor_coord; }); } +static DynamicNodeInvocationShardingInfo invocation_sharding_info_for_binding( + DynamicNodeInvocation const &i, + MachineSpaceCoordinate const &machine_coord, + OperatorAtomicTaskShardBinding const &binding) { + + auto shard_expand_value_attrs = + [&](DynamicTensorSlot const &s, DynamicValueAttrs const &v) -> DynamicValueAttrsShardingInfo { + ParallelTensorSpaceCoordinate parallel_tensor_coord = + binding.tensor_coords.at(s.slot_name); + + return DynamicValueAttrsShardingInfo{ + /*shard_coord=*/parallel_tensor_coord, + /*mapping=*/restrict_tensor_mapping_keys_to_coord(v.mapping.value(), parallel_tensor_coord), + }; + }; + + DynamicNodeAttrs expanded_node_attrs = [&]() { + DynamicNodeAttrs result = i.node_attrs; + result.device_coord = machine_coord; + return result; + }(); + + return DynamicNodeInvocationShardingInfo{ + /*device_coord=*/machine_coord, + /*value_sharding=*/map_from_unordered( + map_values2( + binary_merge_disjoint_maps(i.inputs, i.outputs), + shard_expand_value_attrs)), + }; +} + static DynamicNodeInvocation shard_invocation_for_binding( DynamicNodeInvocation const &i, MachineSpaceCoordinate const &machine_coord, OperatorAtomicTaskShardBinding const &binding) { + auto shard_expand_value_attrs = [&](DynamicTensorSlot const &s, DynamicValueAttrs const &v) -> DynamicValueAttrs { @@ -63,8 +166,9 @@ static DynamicNodeInvocation shard_invocation_for_binding( result.shard_coord = parallel_tensor_coord; result.mapping = transform( v.mapping, - [&](bidict const - &mapping) { + [&](OneToMany const &mapping) + -> OneToMany + { return restrict_tensor_mapping_keys_to_coord(mapping, parallel_tensor_coord); }); @@ -84,124 +188,19 @@ static DynamicNodeInvocation shard_invocation_for_binding( }; } -static std::unordered_set - perform_shard_expansion_for_replicate(DynamicNodeInvocation const &i) { - auto const &[input_slot, input] = get_only(i.inputs); - auto const &[output_slot, output] = get_only(i.outputs); - - bidict input_mapping = - assert_unwrap(input.mapping); - bidict output_mapping = - assert_unwrap(output.mapping); - - return transform(output_mapping.left_values(), - [&](ParallelTensorSpaceCoordinate const &p) { - ParallelTensorSpaceCoordinate input_p{ - /*sum_component=*/p.sum_component, - /*discard_copy_component=*/nonnegative_int{0}, - /*shard_components=*/p.shard_components, - }; - return shard_invocation_for_binding( - i, - output_mapping.at_l(p), - OperatorAtomicTaskShardBinding{{ - {input_slot.slot_name, input_p}, - {output_slot.slot_name, p}, - }}); - }); -} - -static std::unordered_set - perform_shard_expansion_for_replicate_bwd(DynamicNodeInvocation const &i) { - - std::optional output_grad_opt; - std::optional output_fwd_opt; - std::optional output_grad_slot_opt; - std::optional output_fwd_slot_opt; - - for (auto const &[slot, value] : i.inputs) { - if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { - output_grad_slot_opt = slot; - output_grad_opt = value; - } else { - output_fwd_slot_opt = slot; - output_fwd_opt = value; - } - } - - DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); - DynamicValueAttrs output_fwd = assert_unwrap(output_fwd_opt); - DynamicTensorSlot output_grad_slot = assert_unwrap(output_grad_slot_opt); - DynamicTensorSlot output_fwd_slot = assert_unwrap(output_fwd_slot_opt); - auto const &[input_grad_slot, input_grad] = get_only(i.outputs); - - bidict - output_grad_mapping = assert_unwrap(output_grad.mapping); - bidict - input_grad_mapping = assert_unwrap(input_grad.mapping); - - std::unordered_map, - std::unordered_set> - by_shard; - for (auto const &p : output_grad_mapping.left_values()) { - by_shard[p.shard_components].insert(p); - } - - std::unordered_set result; - for (auto const &[shard_components, replica_coords] : by_shard) { - ParallelTensorSpaceCoordinate src_p{ - nonnegative_int{0}, nonnegative_int{0}, shard_components}; - MachineSpaceCoordinate src_machine = input_grad_mapping.at_l(src_p); - - bidict - replica_mapping; - for (auto const &p : replica_coords) { - replica_mapping.equate(p, output_grad_mapping.at_l(p)); - } - - DynamicValueAttrs sharded_output_grad = output_grad; - sharded_output_grad.mapping = replica_mapping; - sharded_output_grad.shard_coord = src_p; - - DynamicValueAttrs sharded_output_fwd = output_fwd; - sharded_output_fwd.mapping = replica_mapping; - sharded_output_fwd.shard_coord = src_p; - - DynamicValueAttrs sharded_input_grad = input_grad; - sharded_input_grad.mapping = - bidict{ - {src_p, src_machine}}; - sharded_input_grad.shard_coord = src_p; - - DynamicNodeAttrs sharded_node = i.node_attrs; - sharded_node.device_coord = src_machine; - - result.insert(DynamicNodeInvocation{ - /*inputs=*/{ - {output_fwd_slot, sharded_output_fwd}, - {output_grad_slot, sharded_output_grad}, - }, - /*node_attrs=*/sharded_node, - /*outputs=*/ - { - {input_grad_slot, sharded_input_grad}, - }, - }); - } - return result; -} - -static std::unordered_set - perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { +static std::set + generate_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); auto [output_slot, output] = get_only(i.outputs); - bidict input_mapping = + + OneToMany input_mapping = assert_unwrap(input.mapping); require_same(input_mapping.left_values(), assert_unwrap(output.mapping).left_values()); return transform( - input_mapping.left_values(), [&](ParallelTensorSpaceCoordinate const &p) { + input_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) -> DynamicNodeInvocationShardingInfo { // The machine coord for a copy is inherently nebulous because it // doesn't strictly run in any single location. Further, Realm has the // flexibility to issue a copy operation from anywhere in the machine, @@ -209,9 +208,9 @@ static std::unordered_set // because we expect this to align with the most efficient way to issue // copies in Realm, although the current Realm backend uses a // centralized controller and thus issues copies all from a single node. - MachineSpaceCoordinate machine_coord = input_mapping.at_l(p); + MachineSpaceCoordinate machine_coord = get_only(input_mapping.at_l(p)); - return shard_invocation_for_binding(i, + return invocation_sharding_info_for_binding(i, machine_coord, OperatorAtomicTaskShardBinding{{ {input_slot.slot_name, p}, @@ -222,28 +221,91 @@ static std::unordered_set std::unordered_set perform_shard_expansion_for_invocation(DynamicNodeInvocation const &i) { - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_copy()) { - return perform_shard_expansion_for_copy(i); - } - bool const is_replicate = - i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().has() && - i.node_attrs.op_attrs.value() - .get() - .has(); - - // forward replicate - if (is_replicate && i.node_attrs.task_type.has_value() && - i.node_attrs.task_type.value() == DynamicTaskType::FWD) { - return perform_shard_expansion_for_replicate(i); - } + std::unordered_set + shard_expansion_info = generate_shard_expansion_for_invocation(i); + + return transform( + shard_expansion_info, + [&](DynamicNodeInvocationShardingInfo const &s) + -> DynamicNodeInvocation + { + return apply_dynamic_node_invocation_sharding_info(i, s); + }); +} + +bool graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const &g) { + auto slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) -> bool { + return false; + }; + + return full_dynamic_graph_satisfies(g, + node_is_ready_for_shard_expansion, + value_is_ready_for_shard_expansion, + slot_is_ready_for_shard_expansion); +} + + +void require_graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const &g) { + auto require_slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) -> void { + return; + }; - // backward replicate - if (is_replicate && i.node_attrs.task_type.has_value() && - i.node_attrs.task_type.value() == DynamicTaskType::BWD) { - return perform_shard_expansion_for_replicate_bwd(i); + return require_full_dynamic_graph_satisfies(g, + require_node_is_ready_for_shard_expansion, + require_value_is_ready_for_shard_expansion, + require_slot_is_ready_for_shard_expansion); +} + +DynamicNodeAttrs apply_dynamic_node_attrs_sharding_info( + DynamicNodeAttrs const &node_attrs, + MachineSpaceCoordinate const &device_coord) +{ + DynamicNodeAttrs result = node_attrs; + result.device_coord = device_coord; + + return result; +} + +DynamicValueAttrs apply_dynamic_value_attrs_sharding_info( + DynamicValueAttrs const &value_attrs, + DynamicValueAttrsShardingInfo const &value_sharding_info) +{ + DynamicValueAttrs result = value_attrs; + result.shard_coord = value_sharding_info.shard_coord; + result.mapping = value_sharding_info.mapping; + return result; +} + +DynamicNodeInvocation apply_dynamic_node_invocation_sharding_info( + DynamicNodeInvocation const &invocation, + DynamicNodeInvocationShardingInfo const &invocation_sharding_info) +{ + require_invocation_is_ready_for_shard_expansion(invocation); + + auto shard_value = [&](DynamicTensorSlot const &slot, DynamicValueAttrs const &value_attrs) -> DynamicValueAttrs { + DynamicValueAttrsShardingInfo sharding_info = invocation_sharding_info.value_sharding.at(slot); + return apply_dynamic_value_attrs_sharding_info(value_attrs, sharding_info); + }; + + DynamicNodeInvocation result = DynamicNodeInvocation{ + /*inputs=*/map_values2(invocation.inputs, shard_value), + /*node_attrs=*/apply_dynamic_node_attrs_sharding_info(invocation.node_attrs, invocation_sharding_info.device_coord), + /*outputs=*/map_values2(invocation.outputs, shard_value), + }; + + ASSERT(invocation_is_fully_shard_expanded(result)); + return result; +} + +std::unordered_set + generate_shard_expansion_for_invocation(DynamicNodeInvocation const &i) +{ + require_invocation_is_ready_for_shard_expansion(i); + + if (i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().is_copy()) { + return unordered_set_of(generate_shard_expansion_for_copy(i)); } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); @@ -253,11 +315,11 @@ std::unordered_set return transform( shard_machine_coords, - [&](MachineSpaceCoordinate const &c) -> DynamicNodeInvocation { + [&](MachineSpaceCoordinate const &c) -> DynamicNodeInvocationShardingInfo { OperatorAtomicTaskShardBinding slot_bindings = mapping.get_shard_bindings().at_l(c); - return shard_invocation_for_binding(i, c, slot_bindings); + return invocation_sharding_info_for_binding(i, c, slot_bindings); }); } @@ -265,6 +327,7 @@ DynamicOpenDataflowGraph perform_shard_expansion(DynamicOpenDataflowGraph const &g) { ASSERT(no_part_of_graph_is_shard_expanded(g)); + require_graph_is_ready_for_shard_expansion(g); DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set(g, [&](DynamicNodeInvocation const &i) { diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc index 2160f6bf82..fdc705dc54 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc @@ -6,11 +6,13 @@ #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "test/utils/doctest/fmt/unordered_set.h" #include +#include "task-spec/dynamic_graph/dynamic_value_attrs.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("perform_copy_insertion_for_invocation") { + TEST_CASE("copies_for_invocation_inputs") { auto mk_machine_coord = [](nonnegative_int node_idx, nonnegative_int device_idx) -> MachineSpaceCoordinate { @@ -21,6 +23,33 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + }; + }; + + auto mk_value = [](size_t src_node_id, + TensorSlotName src_slot_name) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{ + parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{src_node_id}, + src_slot_name, + }, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + auto mk_pt_coord = [](nonnegative_int idx1, nonnegative_int idx2, @@ -37,397 +66,567 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; - auto mk_input_shard_binding = [&](ParallelTensorSpaceCoordinate const &c) - -> OperatorAtomicTaskShardBinding { - return OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ + size_t invocation_id = 20; + + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + + SUBCASE("standard operator") { + auto mk_input_shard_binding = [&](ParallelTensorSpaceCoordinate const &c) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::OUTPUT, + c, + }, + }, + }; + }; + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2, + ParallelTensorSpaceCoordinate const &c3, + ParallelTensorSpaceCoordinate const &c4) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::WEIGHT, + c2, + }, + { + TensorSlotName::OUTPUT_1, + c3, + }, + { + TensorSlotName::OUTPUT_2, + c4, + }, + }, + }; + }; + + ParallelTensorSpaceCoordinate mc1_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc1_weight_coord = + mk_pt_coord(0_n, 1_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc1_output_1_coord = + mk_pt_coord(1_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc1_output_2_coord = + mk_pt_coord(3_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc2_input_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_weight_coord = + mk_pt_coord(0_n, 4_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_1_coord = + mk_pt_coord(1_n, 2_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc2_output_2_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + MappedOperatorTaskGroup input_mapping_same = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_input_shard_binding(mc1_input_coord), + }, { - TensorSlotName::OUTPUT, - c, + mc2, + mk_input_shard_binding(mc2_input_coord), }, }, }; - }; - auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, - ParallelTensorSpaceCoordinate const &c2, - ParallelTensorSpaceCoordinate const &c3, - ParallelTensorSpaceCoordinate const &c4) - -> OperatorAtomicTaskShardBinding { - return OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ + MappedOperatorTaskGroup weight_mapping_same = MappedOperatorTaskGroup{ + bidict{ { - TensorSlotName::INPUT, - c1, + mc1, + mk_input_shard_binding(mc1_weight_coord), }, { - TensorSlotName::WEIGHT, - c2, + mc2, + mk_input_shard_binding(mc2_weight_coord), }, + }, + }; + + MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ + bidict{ { - TensorSlotName::OUTPUT_1, - c3, + mc1, + mk_shard_binding(mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), }, { - TensorSlotName::OUTPUT_2, - c4, + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), }, }, }; - }; - MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); - MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); - MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); - MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + MappedOperatorTaskGroup invocation_mapping_diff_vs_copy1 = + MappedOperatorTaskGroup{ + bidict{ + { + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }, + }, + }; - ParallelTensorSpaceCoordinate mc1_input_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc1_weight_coord = - mk_pt_coord(0_n, 1_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc1_output_1_coord = - mk_pt_coord(1_n, 0_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc1_output_2_coord = - mk_pt_coord(3_n, 0_n, 0_n, 0_n); - - ParallelTensorSpaceCoordinate mc2_input_coord = - mk_pt_coord(0_n, 1_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc2_weight_coord = - mk_pt_coord(0_n, 4_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc2_output_1_coord = - mk_pt_coord(1_n, 2_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc2_output_2_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - - MappedOperatorTaskGroup input_mapping_same = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_input_shard_binding(mc1_input_coord), - }, - { - mc2, - mk_input_shard_binding(mc2_input_coord), - }, - }, - }; + DynamicValueAttrs graph_input1 = + mk_value(0, TensorSlotName::OUTPUT); + + DynamicValueAttrs graph_input1_use = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::INPUT)); + + DynamicValueAttrs graph_input1_use_diff_vs_copy1 = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(invocation_mapping_diff_vs_copy1, TensorSlotName::INPUT)); + + DynamicValueAttrs graph_input2 = + mk_value(1, TensorSlotName::OUTPUT); + + DynamicValueAttrs graph_input2_use = + decide_dynamic_value_attrs_mapping( + graph_input2, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::WEIGHT)); + + DynamicValueAttrs invocation_output1 = mk_value(invocation_id, + TensorSlotName::OUTPUT_1); + DynamicValueAttrs invocation_output1_src = + decide_dynamic_value_attrs_mapping( + invocation_output1, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT_1)); + + DynamicValueAttrs invocation_output2 = mk_value(invocation_id, + TensorSlotName::OUTPUT_2); + DynamicValueAttrs invocation_output2_src = + decide_dynamic_value_attrs_mapping( + invocation_output2, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT_2)); + + DynamicValueAttrs graph_input1_src_same = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_same, TensorSlotName::OUTPUT)); + + DynamicValueAttrs graph_input2_src_same = + decide_dynamic_value_attrs_mapping( + graph_input2, + get_tensor_bindings_for_slot_name(weight_mapping_same, TensorSlotName::OUTPUT)); + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + graph_input1, + }, + { + mk_slot(TensorSlotName::WEIGHT), + graph_input2, + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coord=*/std::nullopt, + /*mapping=*/invocation_mapping, + /*op_attrs=*/std::nullopt, + /*layer_guid=*/ + dynamic_layer_guid_t{parallel_layer_guid_t{Node{invocation_id}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_1), + invocation_output1, + }, + { + mk_slot(TensorSlotName::OUTPUT_2), + invocation_output2, + }, + }, + }; - MappedOperatorTaskGroup weight_mapping_same = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_input_shard_binding(mc1_weight_coord), + auto mk_copy = [&](DynamicValueAttrs const &src, + DynamicValueAttrs const &dst) { + return DynamicNodeInvocation{ + /*inputs=*/{{mk_slot(TensorSlotName::INPUT), src}}, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs*/ TrainingOperationAttrs{CopyAttrs{}}, + /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, + /*per_device_op_state=*/std::nullopt, }, - { - mc2, - mk_input_shard_binding(mc2_weight_coord), + /*outputs=*/{{mk_slot(TensorSlotName::OUTPUT), dst}}, + }; + }; + + SUBCASE("same mapping, no copies") { + std::unordered_map sources_same{ + {graph_input1, graph_input1_src_same}, + {graph_input2, graph_input2_src_same}, + }; + + std::unordered_set result = + copies_for_invocation_inputs(input, sources_same); + + std::unordered_set correct = {}; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("copy one tensor, one point") { + MappedOperatorTaskGroup input_mapping_copy1 = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_input_shard_binding(mc1_input_coord), + }, + { + mc3, + mk_input_shard_binding(mc2_input_coord), + }, }, - }, - }; + }; - MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_shard_binding(mc1_input_coord, - mc1_weight_coord, - mc1_output_1_coord, - mc1_output_2_coord), + MappedOperatorTaskGroup input_mapping_copy1_diff_vs_use = + MappedOperatorTaskGroup{ + bidict{ + { + mc3, + mk_input_shard_binding(mc2_input_coord), + }, + }, + }; + + DynamicValueAttrs graph_input1_src_copy1 = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_copy1, TensorSlotName::OUTPUT)); + + DynamicValueAttrs graph_input1_src_copy1_diff_vs_use = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_copy1_diff_vs_use, TensorSlotName::OUTPUT)); + + std::unordered_map sources_copy1{ + {graph_input1, graph_input1_src_copy1}, + {graph_input2, graph_input2_src_same}}; + + std::unordered_set result = + copies_for_invocation_inputs(input, sources_copy1); + + std::unordered_set correct = { + mk_copy(graph_input1_src_copy1_diff_vs_use, graph_input1_use_diff_vs_copy1), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("copy two tensors, two points") { + MappedOperatorTaskGroup input_mapping_copy2 = MappedOperatorTaskGroup{ + bidict{ + { + mc3, + mk_input_shard_binding(mc1_input_coord), + }, + { + mc4, + mk_input_shard_binding(mc2_input_coord), + }, }, - { - mc2, - mk_shard_binding(mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), + }; + MappedOperatorTaskGroup weight_mapping_copy2 = MappedOperatorTaskGroup{ + bidict{ + { + mc4, + mk_input_shard_binding(mc1_weight_coord), + }, + { + mc3, + mk_input_shard_binding(mc2_weight_coord), + }, }, - }, - }; + }; - MappedOperatorTaskGroup invocation_mapping_diff_vs_copy1 = - MappedOperatorTaskGroup{ - bidict{ + DynamicValueAttrs graph_input1_src_copy2 = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_copy2, TensorSlotName::OUTPUT)); + + DynamicValueAttrs graph_input2_src_copy2 = + decide_dynamic_value_attrs_mapping( + graph_input2, + get_tensor_bindings_for_slot_name(weight_mapping_copy2, TensorSlotName::OUTPUT)); + + std::unordered_map sources_copy2{ + {graph_input1, graph_input1_src_copy2}, + {graph_input2, graph_input2_src_copy2}}; + + std::unordered_set result = + copies_for_invocation_inputs(input, sources_copy2); + + std::unordered_set correct = { + mk_copy(graph_input1_src_copy2, graph_input1_use), + mk_copy(graph_input2_src_copy2, graph_input2_use), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + } + + SUBCASE("replicate operator") { + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, { - mc2, - mk_shard_binding(mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), + TensorSlotName::OUTPUT, + c2, }, }, }; - auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { - return DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), }; - }; - auto mk_value = [&](size_t src_node_id, - TensorSlotName src_slot_name, - MappedOperatorTaskGroup const &mapping, - std::optional const &use_slot_name) - -> DynamicValueAttrs { - return DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{parallel_tensor_guid_t{ - KwargDataflowOutput{ - Node{src_node_id}, - src_slot_name, + ParallelTensorSpaceCoordinate mc_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc1_output_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + + MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(mc_input_coord, + mc1_output_coord), }, - }}, - /*parallel_tensor_shape=*/std::nullopt, - /*shard_coord=*/std::nullopt, - /*mapping=*/ - transform(use_slot_name, - [&](TensorSlotName s) { - return get_tensor_bindings_for_slot_name(mapping, s); - }), - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, + { + mc2, + mk_shard_binding(mc_input_coord, + mc2_output_coord), + }, + }, }; - }; - size_t invocation1_id = 20; - - DynamicValueAttrs graph_input1 = - mk_value(0, TensorSlotName::OUTPUT, invocation_mapping, std::nullopt); - DynamicValueAttrs graph_input1_use = mk_value( - 0, TensorSlotName::OUTPUT, invocation_mapping, TensorSlotName::INPUT); - DynamicValueAttrs graph_input1_use_diff_vs_copy1 = - mk_value(0, - TensorSlotName::OUTPUT, - invocation_mapping_diff_vs_copy1, - TensorSlotName::INPUT); - DynamicValueAttrs graph_input2 = - mk_value(1, TensorSlotName::OUTPUT, invocation_mapping, std::nullopt); - DynamicValueAttrs graph_input2_use = mk_value( - 1, TensorSlotName::OUTPUT, invocation_mapping, TensorSlotName::WEIGHT); - DynamicValueAttrs invocation1_output1 = mk_value(invocation1_id, - TensorSlotName::OUTPUT_1, - invocation_mapping, - std::nullopt); - DynamicValueAttrs invocation1_output1_src = - mk_value(invocation1_id, - TensorSlotName::OUTPUT_1, - invocation_mapping, - TensorSlotName::OUTPUT_1); - DynamicValueAttrs invocation1_output2 = mk_value(invocation1_id, - TensorSlotName::OUTPUT_2, - invocation_mapping, - std::nullopt); - DynamicValueAttrs invocation1_output2_src = - mk_value(invocation1_id, - TensorSlotName::OUTPUT_2, - invocation_mapping, - TensorSlotName::OUTPUT_2); - - DynamicValueAttrs graph_input1_src_same = mk_value( - 0, TensorSlotName::OUTPUT, input_mapping_same, TensorSlotName::OUTPUT); - DynamicValueAttrs graph_input2_src_same = mk_value( - 1, TensorSlotName::OUTPUT, weight_mapping_same, TensorSlotName::OUTPUT); - - DynamicNodeInvocation input = DynamicNodeInvocation{ + DynamicValueAttrs graph_input_unmapped = + mk_value(0, TensorSlotName::OUTPUT); + DynamicValueAttrs graph_input_use_mapped = + decide_dynamic_value_attrs_mapping( + graph_input_unmapped, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::INPUT)); + + DynamicValueAttrs invocation_output_unmapped = + mk_value(invocation_id, TensorSlotName::OUTPUT); + DynamicValueAttrs invocation_output_src_mapped = + decide_dynamic_value_attrs_mapping( + invocation_output_unmapped, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT)); + + DynamicNodeInvocation input = DynamicNodeInvocation{ /*inputs=*/{ { mk_slot(TensorSlotName::INPUT), - graph_input1, - }, - { - mk_slot(TensorSlotName::WEIGHT), - graph_input2, + graph_input_unmapped, }, }, - /*node_attrs=*/ - DynamicNodeAttrs{ + /*node_attrs=*/DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/invocation_mapping, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/ - dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT_1), - invocation1_output1, - }, - { - mk_slot(TensorSlotName::OUTPUT_2), - invocation1_output2, - }, - }, - }; - - DynamicNodeInvocation mapped = DynamicNodeInvocation{ - /*inputs=*/{ - { - mk_slot(TensorSlotName::INPUT), - graph_input1_use, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + 2_p, + }, + }, }, - { - mk_slot(TensorSlotName::WEIGHT), - graph_input2_use, + /*layer_guid=*/dynamic_layer_guid_t{ + parallel_layer_guid_t{ + Node{invocation_id}, + }, }, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/DynamicTaskType::FWD, - /*device_coord=*/std::nullopt, - /*mapping=*/invocation_mapping, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/ - dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, /*per_device_op_state=*/std::nullopt, }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT_1), - invocation1_output1_src, - }, + /*outputs=*/{ { - mk_slot(TensorSlotName::OUTPUT_2), - invocation1_output2_src, + mk_slot(TensorSlotName::OUTPUT), + invocation_output_unmapped, }, }, - }; - - auto mk_copy = [&](DynamicValueAttrs const &src, - DynamicValueAttrs const &dst) { - return DynamicNodeInvocation{ - /*inputs=*/{{mk_slot(TensorSlotName::INPUT), src}}, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/DynamicTaskType::FWD, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs*/ TrainingOperationAttrs{CopyAttrs{}}, - /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/{{mk_slot(TensorSlotName::OUTPUT), dst}}, }; - }; - SUBCASE("same mapping, no copies") { - std::unordered_map sources_same{ - {graph_input1, graph_input1_src_same}, - {graph_input2, graph_input2_src_same}}; - - std::unordered_set result = - perform_copy_insertion_for_invocation(input, sources_same); - - std::unordered_set correct = {mapped}; - - CHECK(result.size() == correct.size()); - CHECK(result == correct); - } - - SUBCASE("copy one tensor, one point") { - MappedOperatorTaskGroup input_mapping_copy1 = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_input_shard_binding(mc1_input_coord), - }, - { - mc3, - mk_input_shard_binding(mc2_input_coord), - }, + std::unordered_map unmapped_to_mapped_source_value = { + { + graph_input_unmapped, + decide_dynamic_value_attrs_mapping( + graph_input_unmapped, + OneToMany{ + { + mc_input_coord, + {mc3}, + }, + }) }, - }; - MappedOperatorTaskGroup input_mapping_copy1_diff_vs_use = - MappedOperatorTaskGroup{ - bidict{ - { - mc3, - mk_input_shard_binding(mc2_input_coord), - }, - }, - }; + }; - DynamicValueAttrs graph_input1_src_copy1 = - mk_value(0, - TensorSlotName::OUTPUT, - input_mapping_copy1, - TensorSlotName::OUTPUT); - DynamicValueAttrs graph_input1_src_copy1_diff_vs_use = - mk_value(0, - TensorSlotName::OUTPUT, - input_mapping_copy1_diff_vs_use, - TensorSlotName::OUTPUT); - - std::unordered_map sources_copy1{ - {graph_input1, graph_input1_src_copy1}, - {graph_input2, graph_input2_src_same}}; - - std::unordered_set result = - perform_copy_insertion_for_invocation(input, sources_copy1); - - std::unordered_set correct = { - mapped, - mk_copy(graph_input1_src_copy1_diff_vs_use, - graph_input1_use_diff_vs_copy1), - }; + std::unordered_set result = copies_for_invocation_inputs( + input, unmapped_to_mapped_source_value); - CHECK(result.size() == correct.size()); - CHECK(result == correct); - } + std::unordered_set correct = {}; - SUBCASE("copy two tensors, two points") { - MappedOperatorTaskGroup input_mapping_copy2 = MappedOperatorTaskGroup{ - bidict{ - { - mc3, - mk_input_shard_binding(mc1_input_coord), - }, - { - mc4, - mk_input_shard_binding(mc2_input_coord), - }, - }, - }; - MappedOperatorTaskGroup weight_mapping_copy2 = MappedOperatorTaskGroup{ - bidict{ - { - mc4, - mk_input_shard_binding(mc1_weight_coord), - }, - { - mc3, - mk_input_shard_binding(mc2_weight_coord), - }, - }, - }; + nlohmann::json result_j = transform(result, dynamic_node_invocation_to_serializable); + nlohmann::json correct_j = transform(correct, dynamic_node_invocation_to_serializable); - DynamicValueAttrs graph_input1_src_copy2 = - mk_value(0, - TensorSlotName::OUTPUT, - input_mapping_copy2, - TensorSlotName::OUTPUT); - DynamicValueAttrs graph_input2_src_copy2 = - mk_value(1, - TensorSlotName::OUTPUT, - weight_mapping_copy2, - TensorSlotName::OUTPUT); - - std::unordered_map sources_copy2{ - {graph_input1, graph_input1_src_copy2}, - {graph_input2, graph_input2_src_copy2}}; - - std::unordered_set result = - perform_copy_insertion_for_invocation(input, sources_copy2); - - std::unordered_set correct = { - mapped, - mk_copy(graph_input1_src_copy2, graph_input1_use), - mk_copy(graph_input2_src_copy2, graph_input2_use), - }; - - CHECK(result.size() == correct.size()); - CHECK(result == correct); + CHECK(result_j == correct_j); } + + // SUBCASE("reduction operator") { + + // auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + // ParallelTensorSpaceCoordinate const &c2) + // -> OperatorAtomicTaskShardBinding { + // return OperatorAtomicTaskShardBinding{ + // /*tensor_coords=*/{ + // { + // TensorSlotName::INPUT, + // c1, + // }, + // { + // TensorSlotName::OUTPUT, + // c2, + // }, + // }, + // }; + // }; + + // ParallelTensorSpaceCoordinate mc1_input_coord = + // mk_pt_coord(0_n, 0_n, 0_n, 0_n); + // ParallelTensorSpaceCoordinate mc2_input_coord = + // mk_pt_coord(1_n, 0_n, 0_n, 0_n); + + // ParallelTensorSpaceCoordinate mc_output_coord = + // mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + // MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ + // bidict{ + // { + // mc3, + // mk_shard_binding(mc1_input_coord, + // mc_output_coord), + // }, + // }, + // }; + + // DynamicValueAttrs graph_input_unmapped = + // mk_value(0, TensorSlotName::OUTPUT); + // DynamicValueAttrs graph_input_use_mapped = + // decide_dynamic_value_attrs_mapping( + // graph_input_unmapped, + // get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::INPUT)); + + // DynamicValueAttrs invocation_output_unmapped = + // mk_value(invocation_id, TensorSlotName::OUTPUT); + // DynamicValueAttrs invocation_output_src_mapped = + // decide_dynamic_value_attrs_mapping( + // invocation_output_unmapped, + // get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT)); + + // DynamicNodeInvocation input = DynamicNodeInvocation{ + // /*inputs=*/{ + // { + // mk_slot(TensorSlotName::INPUT), + // graph_input_unmapped, + // }, + // }, + // /*node_attrs=*/DynamicNodeAttrs{ + // /*task_type=*/DynamicTaskType::FWD, + // /*device_coord=*/std::nullopt, + // /*mapping=*/invocation_mapping, + // /*op_attrs=*/TrainingOperationAttrs{ + // PCGOperatorAttrs{ + // ReductionAttrs{ + // /*reduction_degree=*/2_p, + // }, + // }, + // }, + // /*layer_guid=*/dynamic_layer_guid_t{ + // parallel_layer_guid_t{ + // Node{invocation_id}, + // }, + // }, + // /*per_device_op_state=*/std::nullopt, + // }, + // /*outputs=*/{ + // { + // mk_slot(TensorSlotName::OUTPUT), + // invocation_output_unmapped, + // }, + // }, + // }; + + // std::unordered_map unmapped_to_mapped_source_value = { + // { + // graph_input_unmapped, + // decide_dynamic_value_attrs_mapping( + // graph_input_unmapped, + // OneToMany{ + // { + // mc1_input_coord, + // {mc1}, + // }, + // { + // mc2_input_coord, + // {mc2}, + // }, + // }) + // }, + // }; + + // std::unordered_set result = copies_for_invocation_inputs( + // input, unmapped_to_mapped_source_value); + + // std::unordered_set correct = {}; + + // nlohmann::json result_j = transform(result, dynamic_node_invocation_to_serializable); + // nlohmann::json correct_j = transform(correct, dynamic_node_invocation_to_serializable); + + // CHECK(result_j == correct_j); + // } } } diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc new file mode 100644 index 0000000000..9f8aeee726 --- /dev/null +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -0,0 +1,396 @@ +#include +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h" +#include "utils/containers/require_only_key.h" +#include "op-attrs/ops/element_unary.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_dynamic_node_invocation_from_mapped") { + SUBCASE("Replicate") { + MachineSpaceCoordinate gpu0 = MachineSpaceCoordinate{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1 = MachineSpaceCoordinate{0_n, 1_n, DeviceType::GPU}; + + ParallelTensorSpaceCoordinate tensor_coord0 = ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_component=*/FFOrdered{0_n}, + }; + + ParallelTensorSpaceCoordinate tensor_coord1 = ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/1_n, + /*shard_component=*/FFOrdered{0_n}, + }; + + MappedOperatorTaskGroup mapping = MappedOperatorTaskGroup{ + { + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }; + + ParallelTensorShape input_shape = ParallelTensorShape{ + /*dims=*/ParallelTensorDims{ + /*shard_dims=*/FFOrdered{ + ShardParallelDim{8_p, 2_p}, + ShardParallelDim{5_p, 1_p}, + }, + /*replica_dims=*/ReplicaParallelDimSet{ + SumDegree{1_p}, + DiscardCopyDegree{1_p}, + }, + }, + /*data_type=*/DataType::FLOAT, + }; + + ParallelTensorShape output_shape = [&] { + ParallelTensorShape shape = input_shape; + shape.dims.replica_dims.discard_copy_degree = DiscardCopyDegree{2_p}; + return shape; + }(); + + PCGOperatorAttrs op_attrs = PCGOperatorAttrs{ + ReplicateAttrs{ + 2_p, + }, + }; + + parallel_layer_guid_t layer_guid = parallel_layer_guid_t{Node{0}}; + parallel_tensor_guid_t input_tensor_guid = parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{5}, + TensorSlotName::OUTPUT, + }, + }; + parallel_tensor_guid_t output_tensor_guid = parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{0}, + TensorSlotName::OUTPUT, + }, + }; + + MappedParallelLayerInvocationInfo input = MappedParallelLayerInvocationInfo{ + /*incoming=*/{ + { + TensorSlotName::INPUT, + ParallelTensorInfo{ + /*guid=*/input_tensor_guid, + /*attrs=*/ParallelTensorAttrs{ + /*shape=*/input_shape, + /*create_grad=*/CreateGrad::YES, + }, + }, + }, + }, + /*layer_info=*/MappedParallelLayerInfo{ + /*guid=*/layer_guid, + /*attrs=*/ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }, + /*mapping=*/mapping, + }, + /*outgoing=*/{ + { + TensorSlotName::OUTPUT, + ParallelTensorInfo{ + /*guid=*/output_tensor_guid, + /*attrs=*/ParallelTensorAttrs{ + /*shape=*/output_shape, + /*create_grad=*/CreateGrad::YES, + }, + }, + }, + }, + }; + + DynamicNodeInvocation result = make_dynamic_node_invocation_from_mapped(input); + + DynamicNodeInvocation correct = DynamicNodeInvocation{ + /*inputs=*/{ + { + DynamicTensorSlot{ + TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, + /*parallel_tensor_shape=*/input_shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }, + }, + /*node_attrs=*/DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/mapping, + /*op_attrs=*/TrainingOperationAttrs{op_attrs}, + /*layer_guid=*/dynamic_layer_guid_t{layer_guid}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/{ + { + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, + /*parallel_tensor_shape=*/output_shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }, + } + }; + + CHECK(result == correct); + } + + // SUBCASE("standard op") { + // + // } + } + + // TEST_CASE("make_dynamic_open_dataflow_graph_from_mapped_pcg") { + // positive_int batch_size = 10_p; + // positive_int data_dim = 16_p; + // positive_int hidden_dim = 32_p; + // positive_int output_dim = 1_p; + + // auto make_layer_attrs = [](auto const &op_attrs) -> ParallelLayerAttrs { + // return ParallelLayerAttrs{ + // /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + // /*name=*/std::nullopt, + // }; + // }; + + + // TensorShape output_tensor_shape = TensorShape{ + // TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // TensorShape label_tensor_shape = TensorShape{ + // TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // TensorShape input_tensor_shape = TensorShape{ + // TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // ParallelLayerAddedResult inputs_layer = + // pcg_add_input_layer(pcg, input_tensor_shape); + // parallel_tensor_guid_t t_input = + // require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // ParallelLayerAddedResult inputs_layer_2 = + // pcg_add_input_layer(pcg, input_tensor_shape); + // parallel_tensor_guid_t t_input_2 = + // require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + // OperatorType::EW_ADD, + // DataType::FLOAT, + // false, + // false, + // }; + + // ParallelLayerAddedResult add_operator_1 = + // add_parallel_layer(pcg, + // make_layer_attrs(add_attrs), + // { + // { + // TensorSlotName::LHS_INPUT, + // t_input, + // }, + // { + // TensorSlotName::RHS_INPUT, + // t_input_2, + // }, + // }, + // /*weights=*/{}); + + // parallel_tensor_guid_t t_add_1 = + // require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // positive_int replicate_degree = 2_p; + // ReplicateAttrs repl_attrs = ReplicateAttrs{replicate_degree}; + // ParallelLayerAddedResult repl_operator_1 = + // add_parallel_layer(pcg, + // make_layer_attrs(repl_attrs), + // { + // { + // TensorSlotName::INPUT, + // t_add_1, + // }, + // }, + // /*weight=*/{}); + + // parallel_tensor_guid_t t_repl_1 = + // require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // ParallelLayerAddedResult relu_operator_1 = + // add_parallel_layer(pcg, + // make_layer_attrs(make_relu_attrs()), + // /*inputs=*/ + // { + // { + // TensorSlotName::INPUT, + // t_repl_1, + // }, + // }, + // /*weights=*/{}); + + // parallel_tensor_guid_t t_relu_1 = + // require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + // MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + + // ParallelTensorSpaceCoordinate tensor_coord0{ + // /*sum_component=*/0_n, + // /*discard_copy_component=*/0_n, + // /*shard_component=*/FFOrdered{0_n}}; + // ParallelTensorSpaceCoordinate tensor_coord1{ + // /*sum_component=*/0_n, + // /*discard_copy_component=*/1_n, + // /*shard_component=*/FFOrdered{0_n}}; + + // MappedOperatorTaskGroup input_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup input_2_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup add_operator_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::LHS_INPUT, tensor_coord0}, + // {TensorSlotName::RHS_INPUT, tensor_coord0}, + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup repl_operator_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // { + // gpu1, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord1}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup relu_operator_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::INPUT, tensor_coord0}, + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // { + // gpu1, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::INPUT, tensor_coord1}, + // {TensorSlotName::OUTPUT, tensor_coord1}, + // }}, + // }, + // }, + // }; + + // MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + // /*pcg=*/pcg, + // /*mapped_op_task_groups=*/{ + // { + // inputs_layer.parallel_layer, + // input_1_mapping, + // }, + // { + // inputs_layer_2.parallel_layer, + // input_2_mapping, + // }, + // { + // add_operator_1.parallel_layer, + // add_operator_1_mapping, + // }, + // { + // repl_operator_1.parallel_layer, + // repl_operator_1_mapping, + // }, + // { + // relu_operator_1.parallel_layer, + // relu_operator_1_mapping, + // }, + // }); + + + // DynamicOpenDataflowGraph result = make_dynamic_open_dataflow_graph_from_mapped_pcg(mpcg); + + // DynamicNodeInvocation input_1_invocation = DynamicNodeInvocation{ + // DynamicNodeAttrs{ + // /*task_type=*/std::nullopt, + // /*device_coord=*/std::nullopt, + // /*mapping=*/input_1_mapping, + // /*op_attrs=*/TrainingOperationAttrs{ + // /*pcg_layer_guid=*/ + // /*per_device_op_state=*/std::nullopt, + // }, + // }; + + // DynamicNodeInvocation input_2_invocation = + // DynamicNodeInvocation add_operator_1_invocation = + // DynamicNodeInvocation repl_operator_1_invocation = + // DynamicNodeInvocation relu_operator_1_invocation = + + // DynamicOpenDataflowGraph correct = dynamic_open_dataflow_graph_from_invocation_set( + // /*invocations=*/{ + + // }, + // }; + // } +} diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index efe21146db..19c21f5f89 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -4,8 +4,11 @@ #include "task-spec/dynamic_graph/dynamic_copy_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "test/utils/doctest/fmt/unordered_set.h" -#include "utils/bidict/algorithms/filter_keys.h" #include +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "op-attrs/ops/element_unary.h" +#include "utils/one_to_many/one_to_many_filter_keys.h" +#include "utils/one_to_many/one_to_many_filter_values.h" using namespace ::FlexFlow; @@ -43,15 +46,18 @@ DynamicTensorSlot mk_slot(TensorSlotName const &slot_name) { DynamicValueAttrs mk_value(size_t src_node_id, TensorSlotName src_slot_name, - bidict - tensor_binding, - std::optional const &shard_coord) { + OneToMany const &tensor_binding, + std::optional const &shard_coord, + std::optional const &role = std::nullopt) { + + OneToMany mapping = tensor_binding; if (shard_coord.has_value()) { - tensor_binding = filter_keys(tensor_binding, - [&](ParallelTensorSpaceCoordinate const &p) { - return p == shard_coord.value(); - }); + mapping = one_to_many_filter_keys(mapping, + [&](ParallelTensorSpaceCoordinate const &p) { + return p == shard_coord.value(); + }); } + return DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{parallel_tensor_guid_t{ KwargDataflowOutput{ @@ -61,173 +67,148 @@ DynamicValueAttrs }}, /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, - /*mapping=*/ - tensor_binding, + /*mapping=*/mapping, /*accessor=*/std::nullopt, - /*role=*/std::nullopt, + /*role=*/role, }; }; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("perform_shard_expansion_for_invocation") { - auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, - ParallelTensorSpaceCoordinate const &c2, - ParallelTensorSpaceCoordinate const &c3, - ParallelTensorSpaceCoordinate const &c4) - -> OperatorAtomicTaskShardBinding { - return OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ - { - TensorSlotName::INPUT, - c1, - }, - { - TensorSlotName::WEIGHT, - c2, - }, - { - TensorSlotName::OUTPUT_1, - c3, - }, - { - TensorSlotName::OUTPUT_2, - c4, - }, - }, - }; - }; - - MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); - MachineSpaceCoordinate mc2 = mk_machine_coord(2_n, 0_n); - - ParallelTensorSpaceCoordinate mc1_input_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc1_weight_coord = - mk_pt_coord(0_n, 1_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc1_output_1_coord = - mk_pt_coord(1_n, 0_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc1_output_2_coord = - mk_pt_coord(3_n, 0_n, 0_n, 0_n); - - ParallelTensorSpaceCoordinate mc2_input_coord = - mk_pt_coord(0_n, 1_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc2_weight_coord = - mk_pt_coord(0_n, 4_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc2_output_1_coord = - mk_pt_coord(1_n, 2_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc2_output_2_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - - MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_shard_binding(mc1_input_coord, - mc1_weight_coord, - mc1_output_1_coord, - mc1_output_2_coord), - }, - { - mc2, - mk_shard_binding(mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), - }, - }, - }; - + TEST_CASE("generate_shard_expansion_for_invocation") { auto mk_op_value = [&](size_t src_node_id, TensorSlotName src_slot_name, TensorSlotName use_slot_name, - std::optional const &shard_coord) + MappedOperatorTaskGroup const &mapped_task_group, + std::optional const &shard_coord, + std::optional const &role = std::nullopt) -> DynamicValueAttrs { - bidict + OneToMany tensor_binding = get_tensor_bindings_for_slot_name(mapped_task_group, use_slot_name); - return mk_value(src_node_id, src_slot_name, tensor_binding, shard_coord); + return mk_value(src_node_id, src_slot_name, tensor_binding, shard_coord, role); }; - DynamicNodeInvocation input = DynamicNodeInvocation{ - /*inputs=*/{ - { - mk_slot(TensorSlotName::INPUT), - mk_op_value(0, - TensorSlotName::OUTPUT, - TensorSlotName::INPUT, - std::nullopt), - }, - { - mk_slot(TensorSlotName::WEIGHT), - mk_op_value(1, - TensorSlotName::OUTPUT, - TensorSlotName::WEIGHT, - std::nullopt), - }, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mapped_task_group, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/ - dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, - /*per_device_op_state=*/std::nullopt, + auto mk_sharding_info = [&](TensorSlotName slot_name, + ParallelTensorSpaceCoordinate const &shard_coord, + MappedOperatorTaskGroup const &mapped_op_task_group, + MachineSpaceCoordinate const &device_coord) + -> std::pair + { + OneToMany + tensor_binding = get_tensor_bindings_for_slot_name(mapped_op_task_group, + slot_name); + return std::pair{ + mk_slot(slot_name), + DynamicValueAttrsShardingInfo{ + /*shard_coord=*/shard_coord, + /*mapping=*/one_to_many_filter_values(tensor_binding, + [&](MachineSpaceCoordinate const &c) -> bool { + return device_coord == c; + }), }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT_1), - mk_op_value(20, - TensorSlotName::OUTPUT_1, - TensorSlotName::OUTPUT_1, - std::nullopt), - }, - { - mk_slot(TensorSlotName::OUTPUT_2), - mk_op_value(20, - TensorSlotName::OUTPUT_2, - TensorSlotName::OUTPUT_2, - std::nullopt), + }; + }; + + SUBCASE("standard operator") { + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(2_n, 0_n); + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2, + ParallelTensorSpaceCoordinate const &c3, + ParallelTensorSpaceCoordinate const &c4) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::WEIGHT, + c2, + }, + { + TensorSlotName::OUTPUT_1, + c3, + }, + { + TensorSlotName::OUTPUT_2, + c4, + }, }, + }; + }; + + ParallelTensorSpaceCoordinate mc1_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc1_weight_coord = + mk_pt_coord(0_n, 1_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc1_output_1_coord = + mk_pt_coord(1_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc1_output_2_coord = + mk_pt_coord(3_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc2_input_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_weight_coord = + mk_pt_coord(0_n, 4_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_1_coord = + mk_pt_coord(1_n, 2_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc2_output_2_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), }, - }; + }; + + MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), + }, + { + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }, + }, + }; - std::unordered_set result = - perform_shard_expansion_for_invocation(input); - - auto mk_invocation_shard = - [&](MachineSpaceCoordinate const &device_coord, - ParallelTensorSpaceCoordinate const &input_shard_coord, - ParallelTensorSpaceCoordinate const &weight_shard_coord, - ParallelTensorSpaceCoordinate const &output_1_shard_coord, - ParallelTensorSpaceCoordinate const &output_2_shard_coord) - -> DynamicNodeInvocation { - return DynamicNodeInvocation{ + DynamicNodeInvocation input = DynamicNodeInvocation{ /*inputs=*/{ { mk_slot(TensorSlotName::INPUT), mk_op_value(0, TensorSlotName::OUTPUT, TensorSlotName::INPUT, - input_shard_coord), + mapped_task_group, + std::nullopt), }, { mk_slot(TensorSlotName::WEIGHT), mk_op_value(1, TensorSlotName::OUTPUT, TensorSlotName::WEIGHT, - weight_shard_coord), + mapped_task_group, + std::nullopt), }, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/device_coord, + /*device_coord=*/std::nullopt, /*mapping=*/mapped_task_group, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, /*per_device_op_state=*/std::nullopt, @@ -239,112 +220,152 @@ TEST_SUITE(FF_TEST_SUITE) { mk_op_value(20, TensorSlotName::OUTPUT_1, TensorSlotName::OUTPUT_1, - output_1_shard_coord), + mapped_task_group, + std::nullopt), }, { mk_slot(TensorSlotName::OUTPUT_2), mk_op_value(20, TensorSlotName::OUTPUT_2, TensorSlotName::OUTPUT_2, - output_2_shard_coord), + mapped_task_group, + std::nullopt), }, }, }; - }; - std::unordered_set correct = { - mk_invocation_shard(mc1, - mc1_input_coord, - mc1_weight_coord, - mc1_output_1_coord, - mc1_output_2_coord), - mk_invocation_shard(mc2, - mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), - }; + std::unordered_set result = + generate_shard_expansion_for_invocation(input); - CHECK(result.size() == correct.size()); - CHECK(result == correct); - } + auto mk_invocation_shard = + [&](MachineSpaceCoordinate const &device_coord, + ParallelTensorSpaceCoordinate const &input_shard_coord, + ParallelTensorSpaceCoordinate const &weight_shard_coord, + ParallelTensorSpaceCoordinate const &output_1_shard_coord, + ParallelTensorSpaceCoordinate const &output_2_shard_coord) + -> DynamicNodeInvocationShardingInfo { + return DynamicNodeInvocationShardingInfo{ + /*device_coord=*/device_coord, + /*value_sharding=*/{ + mk_sharding_info(TensorSlotName::INPUT, input_shard_coord, mapped_task_group, device_coord), + mk_sharding_info(TensorSlotName::WEIGHT, weight_shard_coord, mapped_task_group, device_coord), + mk_sharding_info(TensorSlotName::OUTPUT_1, output_1_shard_coord, mapped_task_group, device_coord), + mk_sharding_info(TensorSlotName::OUTPUT_2, output_2_shard_coord, mapped_task_group, device_coord), + }, + }; + }; - TEST_CASE("perform_shard_expansion_for_invocation (copy)") { - MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); - MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); - MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); - MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + std::unordered_set correct = { + mk_invocation_shard(mc1, + mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), + mk_invocation_shard(mc2, + mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }; - ParallelTensorSpaceCoordinate pt1 = mk_pt_coord(0_n, 0_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate pt2 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); + nlohmann::json result_json = result; + nlohmann::json correct_json = correct; - bidict src_binding{ - {pt1, mc1}, - {pt2, mc2}, - }; - bidict dst_binding{ - {pt1, mc3}, - {pt2, mc4}, - }; + CHECK(result.size() == correct.size()); + CHECK(result_json == correct_json); + CHECK(result == correct); + } - DynamicNodeInvocation input = DynamicNodeInvocation{ - /*inputs=*/{ - { - mk_slot(TensorSlotName::INPUT), - mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), - }, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{CopyAttrs{}}, - /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT), - mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), - }, - }, - }; + SUBCASE("copy operator") { + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + + ParallelTensorSpaceCoordinate pt1 = mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate pt2 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); - std::unordered_set result = - perform_shard_expansion_for_invocation(input); + OneToMany src_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; + + OneToMany dst_binding{ + {pt1, {mc3}}, + {pt2, {mc4}}, + }; - auto mk_invocation_shard = - [&](MachineSpaceCoordinate const &device_coord, - ParallelTensorSpaceCoordinate const &tensor_shard_coord) - -> DynamicNodeInvocation { - DynamicNodeInvocation result = input; - result.inputs = { + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{CopyAttrs{}}, + /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ { - mk_slot(TensorSlotName::INPUT), - mk_value( - 0, TensorSlotName::OUTPUT, src_binding, tensor_shard_coord), + { + mk_slot(TensorSlotName::OUTPUT), + mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), + }, }, }; - // See perform_shard_expansion_for_copy in shard_expansion.cc for explanation of the choice of device placement. - result.node_attrs.device_coord = device_coord; - result.outputs = { - { + + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + auto mk_invocation_shard = + [&](MachineSpaceCoordinate const &device_coord, + ParallelTensorSpaceCoordinate const &tensor_shard_coord) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coord=*/device_coord, + /*value_sharding=*/std::map{ + { + mk_slot(TensorSlotName::INPUT), + DynamicValueAttrsShardingInfo{ + tensor_shard_coord, + one_to_many_filter_keys(src_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == tensor_shard_coord; + }), + }, + }, + { mk_slot(TensorSlotName::OUTPUT), - mk_value( - 20, TensorSlotName::OUTPUT, dst_binding, tensor_shard_coord), + DynamicValueAttrsShardingInfo{ + tensor_shard_coord, + one_to_many_filter_keys(dst_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == tensor_shard_coord; + }), + }, + }, }, + }; }; - return result; - }; - std::unordered_set correct = { - mk_invocation_shard(mc1, pt1), - mk_invocation_shard(mc2, pt2), - }; + std::unordered_set correct = { + mk_invocation_shard(mc1, pt1), + mk_invocation_shard(mc2, pt2), + }; + + nlohmann::json result_json = result; + nlohmann::json correct_json = correct; - CHECK(result.size() == correct.size()); - CHECK(result == correct); + CHECK(result.size() == correct.size()); + CHECK(result_json == correct_json); + CHECK(result == correct); + } } } diff --git a/lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h b/lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h new file mode 100644 index 0000000000..ad43cd52b6 --- /dev/null +++ b/lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h @@ -0,0 +1,91 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_JSONABLE_ORDERED_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_JSONABLE_ORDERED_VALUE_TYPE_H + +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct jsonable_ordered_value_type { + jsonable_ordered_value_type() = delete; + + jsonable_ordered_value_type(jsonable_ordered_value_type const &) { + PANIC(); + } + jsonable_ordered_value_type &operator=(jsonable_ordered_value_type const &) { + PANIC(); + } + + jsonable_ordered_value_type(jsonable_ordered_value_type &&) { + PANIC(); + } + jsonable_ordered_value_type &operator=(jsonable_ordered_value_type &&) { + PANIC(); + } + + bool operator==(jsonable_ordered_value_type const &) const { + PANIC(); + } + + bool operator!=(jsonable_ordered_value_type const &) const { + PANIC(); + } + + bool operator<(jsonable_ordered_value_type const &) const { + PANIC(); + } + bool operator>(jsonable_ordered_value_type const &) const { + PANIC(); + } + bool operator<=(jsonable_ordered_value_type const &) const { + PANIC(); + } + bool operator>=(jsonable_ordered_value_type const &) const { + PANIC(); + } +}; + +template +std::string format_as(jsonable_ordered_value_type const &) { + PANIC(); +} + +template +std::ostream &operator<<(std::ostream &s, jsonable_ordered_value_type const &x) { + PANIC(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::jsonable_ordered_value_type> { + static ::FlexFlow::jsonable_ordered_value_type from_json(json const &) { + PANIC(); + } + + static void to_json(json &, ::FlexFlow::jsonable_ordered_value_type const &) { + PANIC(); + } +}; + +} // namespace nlohmann + +namespace std { + +template +struct hash<::FlexFlow::jsonable_ordered_value_type> { + size_t operator()(::FlexFlow::jsonable_ordered_value_type const &) const { + PANIC(); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filter_keys.h b/lib/utils/include/utils/bidict/algorithms/bidict_filter_keys.h similarity index 53% rename from lib/utils/include/utils/bidict/algorithms/filter_keys.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filter_keys.h index 2734dfaeb5..4c2ceb840b 100644 --- a/lib/utils/include/utils/bidict/algorithms/filter_keys.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filter_keys.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_KEYS_H #include "utils/bidict/bidict.h" namespace FlexFlow { template -bidict filter_keys(bidict const &m, F &&f) { +bidict bidict_filter_keys(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { if (f(kv.first)) { diff --git a/lib/utils/include/utils/bidict/algorithms/filter_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_filter_values.h similarity index 53% rename from lib/utils/include/utils/bidict/algorithms/filter_values.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filter_values.h index 5817578e79..cb968f2d02 100644 --- a/lib/utils/include/utils/bidict/algorithms/filter_values.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filter_values.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_VALUES_H #include "utils/bidict/bidict.h" namespace FlexFlow { template -bidict filter_values(bidict const &m, F &&f) { +bidict bidict_filter_values(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { if (f(kv.second)) { diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_keys.h similarity index 64% rename from lib/utils/include/utils/bidict/algorithms/filtrans_keys.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filtrans_keys.h index df6495b400..bd9018bd38 100644 --- a/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_keys.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_KEYS_H #include "utils/bidict/bidict.h" @@ -9,7 +9,7 @@ template ::value_type> -bidict filtrans_keys(bidict const &m, F &&f) { +bidict bidict_filtrans_keys(bidict const &m, F &&f) { bidict result; for (auto const &[k, v] : m) { std::optional new_k = f(k); diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_values.h similarity index 63% rename from lib/utils/include/utils/bidict/algorithms/filtrans_values.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filtrans_values.h index 11180938b8..592440f7f6 100644 --- a/lib/utils/include/utils/bidict/algorithms/filtrans_values.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_values.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_VALUES_H #include "utils/bidict/bidict.h" @@ -9,7 +9,7 @@ template ::value_type> -bidict filtrans_values(bidict const &m, F &&f) { +bidict bidict_filtrans_values(bidict const &m, F &&f) { bidict result; for (auto const &[k, v] : m) { std::optional new_v = f(v); diff --git a/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h b/lib/utils/include/utils/bidict/algorithms/bidict_unordered_set_of.h similarity index 51% rename from lib/utils/include/utils/bidict/algorithms/unordered_set_of.h rename to lib/utils/include/utils/bidict/algorithms/bidict_unordered_set_of.h index b3df2514cf..251573d441 100644 --- a/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_unordered_set_of.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNORDERED_SET_OF_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNORDERED_SET_OF_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_UNORDERED_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_UNORDERED_SET_OF_H #include "utils/bidict/bidict.h" #include "utils/hash/pair.h" @@ -7,7 +7,7 @@ namespace FlexFlow { template -std::unordered_set> unordered_set_of(bidict const &c) { +std::unordered_set> bidict_unordered_set_of(bidict const &c) { std::unordered_set> result; for (auto const &lr : c) { diff --git a/lib/utils/include/utils/bidict/algorithms/transform_keys.h b/lib/utils/include/utils/bidict/algorithms/transform_keys.h index 8ecb10c401..1d82464d17 100644 --- a/lib/utils/include/utils/bidict/algorithms/transform_keys.h +++ b/lib/utils/include/utils/bidict/algorithms/transform_keys.h @@ -12,7 +12,7 @@ template transform_keys(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { - result.equate(f(kv.first), kv.second); + result.equate_strict(f(kv.first), kv.second); } return result; } diff --git a/lib/utils/include/utils/bidict/algorithms/transform_values.h b/lib/utils/include/utils/bidict/algorithms/transform_values.h index ef5b34ebe9..fc8655594e 100644 --- a/lib/utils/include/utils/bidict/algorithms/transform_values.h +++ b/lib/utils/include/utils/bidict/algorithms/transform_values.h @@ -12,7 +12,7 @@ template transform_values(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { - result.equate({kv.first, f(kv.second)}); + result.equate_strict({kv.first, f(kv.second)}); } return result; } diff --git a/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h b/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h index 2ceb527b96..63d25332d3 100644 --- a/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h +++ b/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNSTRUCTURED_RELATION_FROM_BIDICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNSTRUCTURED_RELATION_FROM_BIDICT_H -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include "utils/bidict/bidict.h" namespace FlexFlow { @@ -9,7 +9,7 @@ namespace FlexFlow { template std::unordered_set> unstructured_relation_from_bidict(bidict const &b) { - return unordered_set_of(b); + return bidict_unordered_set_of(b); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 2d8c5d23a8..57f8d5e213 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/fmt/unordered_map.h" #include "utils/hash/unordered_map.h" @@ -13,6 +13,9 @@ #include #include #include +#include "utils/containers/require_same.h" +#include "utils/containers/values.h" +#include "utils/containers/unordered_set_of.h" namespace FlexFlow { @@ -65,11 +68,15 @@ struct bidict { void equate(L const &l, R const &r) { fwd_map.insert({l, r}); bwd_map.insert({r, l}); + + this->check_invariants(); } void equate(std::pair const &lr) { fwd_map.insert(lr); bwd_map.insert({lr.second, lr.first}); + + this->check_invariants(); } void equate_strict(L const &l, R const &r) { @@ -87,15 +94,17 @@ struct bidict { } bool operator==(bidict const &other) const { - bool result = this->fwd_map == other.fwd_map; - assert(result == (this->bwd_map == other.bwd_map)); - return result; + return require_same( + (this->fwd_map == other.fwd_map), + (this->bwd_map == other.bwd_map) + ); } bool operator!=(bidict const &other) const { - bool result = this->fwd_map != other.fwd_map; - assert(result == (this->bwd_map != other.bwd_map)); - return result; + return require_same( + (this->fwd_map != other.fwd_map), + (this->bwd_map != other.bwd_map) + ); } R const &at_l(L const &l) const { @@ -107,11 +116,11 @@ struct bidict { } std::unordered_set left_values() const { - return keys(this->fwd_map); + return unordered_keys(this->fwd_map); } std::unordered_set right_values() const { - return keys(this->bwd_map); + return unordered_keys(this->bwd_map); } std::size_t size() const { @@ -226,6 +235,21 @@ struct bidict { : fwd_map(fwd_map), bwd_map(bwd_map) {} private: + void check_invariants() const { + std::unordered_set fwd_l_vals = unordered_keys(this->fwd_map); + std::unordered_set bwd_l_vals = unordered_set_of(values(this->bwd_map)); + + std::unordered_set bwd_r_vals = unordered_keys(this->bwd_map); + std::unordered_set fwd_r_vals = unordered_set_of(values(this->fwd_map)); + + ASSERT(fwd_l_vals == bwd_l_vals); + ASSERT(fwd_r_vals == bwd_r_vals); + + for (L const &l : fwd_l_vals) { + ASSERT(bwd_map.at(fwd_map.at(l)) == l); + } + } + friend struct bidict; std::unordered_map fwd_map; diff --git a/lib/utils/include/utils/containers/all_of.h b/lib/utils/include/utils/containers/all_of.h index ef5aac1c41..15ed234511 100644 --- a/lib/utils/include/utils/containers/all_of.h +++ b/lib/utils/include/utils/containers/all_of.h @@ -8,7 +8,7 @@ namespace FlexFlow { template -bool all_of(C const &c, F &&f) { +[[nodiscard]] bool all_of(C const &c, F &&f) { for (auto const &v : c) { if (!f(v)) { return false; @@ -18,7 +18,7 @@ bool all_of(C const &c, F &&f) { } template -bool all_of(std::unordered_map const &m, F &&f) { +[[nodiscard]] bool all_of(std::unordered_map const &m, F &&f) { for (auto const &[k, v] : m) { if (!f(k, v)) { return false; @@ -29,7 +29,7 @@ bool all_of(std::unordered_map const &m, F &&f) { } template -bool all_of(std::map const &m, F &&f) { +[[nodiscard]] bool all_of(std::map const &m, F &&f) { for (auto const &[k, v] : m) { if (!f(k, v)) { return false; @@ -39,7 +39,7 @@ bool all_of(std::map const &m, F &&f) { return true; } -bool all_of(std::vector const &); +[[nodiscard]] bool all_of(std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h index 06a42327e1..824fe77b39 100644 --- a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h @@ -11,8 +11,8 @@ std::unordered_map binary_merge_disjoint_maps(std::unordered_map const &lhs, std::unordered_map const &rhs) { - std::unordered_set lhs_keys = keys(lhs); - std::unordered_set rhs_keys = keys(rhs); + std::unordered_set lhs_keys = unordered_keys(lhs); + std::unordered_set rhs_keys = unordered_keys(rhs); std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); ASSERT(shared_keys.empty()); diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with.h b/lib/utils/include/utils/containers/binary_merge_maps_with.h index a7c196d061..2d0b57eb81 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/intersection.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/merge_maps_with_right_dominating.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" @@ -17,8 +17,8 @@ std::unordered_map std::unordered_map const &rhs, F &&f) { - std::unordered_set l_keys = keys(lhs); - std::unordered_set r_keys = keys(rhs); + std::unordered_set l_keys = unordered_keys(lhs); + std::unordered_set r_keys = unordered_keys(rhs); std::unordered_set l_only_keys = set_minus(l_keys, r_keys); std::unordered_set r_only_keys = set_minus(r_keys, l_keys); @@ -27,7 +27,7 @@ std::unordered_map std::unordered_map l_only = restrict_keys(lhs, l_only_keys); std::unordered_map r_only = restrict_keys(rhs, r_only_keys); - std::unordered_map merged = generate_map( + std::unordered_map merged = generate_unordered_map( both_keys, [&](K const &k) { return f(lhs.at(k), rhs.at(k)); }); return merge_maps_with_right_dominating(std::vector{ diff --git a/lib/utils/include/utils/containers/filter.h b/lib/utils/include/utils/containers/filter.h index 07f25dc348..85a413c2c7 100644 --- a/lib/utils/include/utils/containers/filter.h +++ b/lib/utils/include/utils/containers/filter.h @@ -44,6 +44,13 @@ std::map filter(std::map const &m, F const &f) { return result; } +template +std::multiset filter(std::multiset const &m, F const &f) { + std::multiset result; + std::copy_if(m.cbegin(), m.cend(), std::inserter(result, result.begin()), f); + return result; +} + template std::unordered_multiset filter(std::unordered_multiset const &m, F const &f) { diff --git a/lib/utils/include/utils/containers/generate_map.h b/lib/utils/include/utils/containers/generate_map.h index 53b2a590c5..08bfc86350 100644 --- a/lib/utils/include/utils/containers/generate_map.h +++ b/lib/utils/include/utils/containers/generate_map.h @@ -5,7 +5,7 @@ #include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" #include "utils/type_traits_core.h" -#include +#include namespace FlexFlow { @@ -13,8 +13,8 @@ template , typename V = std::invoke_result_t> -std::unordered_map generate_map(C const &c, F const &f) { - static_assert(is_hashable_v, "Key type should be hashable (but is not)"); +std::map generate_map(C const &c, F &&f) { + static_assert(is_lt_comparable_v, "Key type should be ordered (but is not)"); auto transformed = vector_transform(vector_of(c), [&](K const &k) -> std::pair { diff --git a/lib/utils/include/utils/containers/generate_unordered_map.h b/lib/utils/include/utils/containers/generate_unordered_map.h new file mode 100644 index 0000000000..d57632b7ae --- /dev/null +++ b/lib/utils/include/utils/containers/generate_unordered_map.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_UNORDERED_MAP_H + +#include "utils/containers/get_element_type.h" +#include "utils/containers/vector_of.h" +#include "utils/containers/vector_transform.h" +#include "utils/type_traits_core.h" +#include + +namespace FlexFlow { + +template , + typename V = std::invoke_result_t> +std::unordered_map generate_unordered_map(C const &c, F &&f) { + static_assert(is_hashable_v, "Key type should be hashable (but is not)"); + + auto transformed = + vector_transform(vector_of(c), [&](K const &k) -> std::pair { + return {k, f(k)}; + }); + return {transformed.cbegin(), transformed.cend()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h index 9981948f47..8f77ffbc24 100644 --- a/lib/utils/include/utils/containers/get_all_assignments.h +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H #include "utils/containers/cartesian_product.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_map_from_pairs.h" #include "utils/containers/unordered_set_of.h" @@ -26,7 +26,7 @@ std::unordered_set> get_all_assignments( return {{}}; } - std::vector ordered_keys = vector_of(keys(options_per_key)); + std::vector ordered_keys = vector_of(unordered_keys(options_per_key)); std::vector> ordered_value_option_sets = transform( ordered_keys, [&](K const &k) { return options_per_key.at(k); }); diff --git a/lib/utils/include/utils/containers/index.dox b/lib/utils/include/utils/containers/index.dox index 9b3865dd78..2ffda1cfdf 100644 --- a/lib/utils/include/utils/containers/index.dox +++ b/lib/utils/include/utils/containers/index.dox @@ -9,7 +9,7 @@ Some of the most commonly-used functions are listed below, but you should ideall - \ref containers/transform.h - \ref containers/filter.h - \ref containers/contains.h -- \ref containers/generate_map.h +- \ref containers/generate_unordered_map.h - \ref containers/get_only.h - \ref containers/slice.h - \ref containers/merge_disjoint_maps.h diff --git a/lib/utils/include/utils/containers/is_submapeq_of.h b/lib/utils/include/utils/containers/is_submapeq_of.h index 03cb5ccd78..e50e85e745 100644 --- a/lib/utils/include/utils/containers/is_submapeq_of.h +++ b/lib/utils/include/utils/containers/is_submapeq_of.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/restrict_keys.h" #include @@ -10,7 +10,7 @@ namespace FlexFlow { template bool is_submapeq_of(std::unordered_map const &sub, std::unordered_map const &m) { - return restrict_keys(m, keys(sub)) == sub; + return restrict_keys(m, unordered_keys(sub)) == sub; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/items.h b/lib/utils/include/utils/containers/items.h index 8e3ba95d6c..13e745b17e 100644 --- a/lib/utils/include/utils/containers/items.h +++ b/lib/utils/include/utils/containers/items.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ITEMS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ITEMS_H -#include +#include namespace FlexFlow { template -std::unordered_set> +std::set> items(C const &c) { return {c.begin(), c.end()}; } diff --git a/lib/utils/include/utils/containers/keys.h b/lib/utils/include/utils/containers/keys.h index e14612541e..bd080b7087 100644 --- a/lib/utils/include/utils/containers/keys.h +++ b/lib/utils/include/utils/containers/keys.h @@ -3,13 +3,13 @@ #include #include -#include +#include namespace FlexFlow { template -std::unordered_set keys(std::unordered_map const &c) { - std::unordered_set result; +std::set keys(std::unordered_map const &c) { + std::set result; for (auto const &kv : c) { result.insert(kv.first); } @@ -17,8 +17,8 @@ std::unordered_set keys(std::unordered_map const &c) { } template -std::unordered_set keys(std::map const &c) { - std::unordered_set result; +std::set keys(std::map const &c) { + std::set result; for (auto const &kv : c) { result.insert(kv.first); } diff --git a/lib/utils/include/utils/containers/lookup_in_map.h b/lib/utils/include/utils/containers/lookup_in_map.h index 946fc589db..339b13f042 100644 --- a/lib/utils/include/utils/containers/lookup_in_map.h +++ b/lib/utils/include/utils/containers/lookup_in_map.h @@ -1,24 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_LOOKUP_IN_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_LOOKUP_IN_MAP_H -#include "utils/containers/contains.h" -#include "utils/containers/keys.h" -#include "utils/exception.h" #include "utils/fmt/unordered_map.h" +#include "utils/containers/contains_key.h" #include #include #include +#include namespace FlexFlow { template -std::function lookup_in_map(std::unordered_map const &map) { - return [map](K const &key) -> V { - if (!contains(keys(map), key)) { - throw mk_runtime_error(fmt::format( - "Key {} is not present in the underlying map {}", key, map)); +std::function lookup_in_map(std::unordered_map const &m) { + return [m](K const &key) -> V { + if (!contains_key(m, key)) { + PANIC("Key {} is not present in the underlying map {}", key, m); } - return map.at(key); + return m.at(key); }; } diff --git a/lib/utils/include/utils/containers/map_from_unordered.h b/lib/utils/include/utils/containers/map_from_unordered.h new file mode 100644 index 0000000000..1451cd1aa8 --- /dev/null +++ b/lib/utils/include/utils/containers/map_from_unordered.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_UNORDERED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_UNORDERED_H + +#include +#include + +namespace FlexFlow { + +template +std::map map_from_unordered(std::unordered_map const &u) { + std::map result{u.cbegin(), u.cend()}; + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys2.h b/lib/utils/include/utils/containers/map_keys2.h index fd848f18d8..da68fe05b4 100644 --- a/lib/utils/include/utils/containers/map_keys2.h +++ b/lib/utils/include/utils/containers/map_keys2.h @@ -19,7 +19,7 @@ std::unordered_map map_keys2(std::unordered_map const &m, result.insert({f(kv.first, kv.second), kv.second}); } - ASSERT(keys(m).size() == keys(result).size(), + ASSERT(m.size() == result.size(), "keys passed to map_keys must be transformed into distinct keys"); return result; diff --git a/lib/utils/include/utils/containers/map_keys_and_values.h b/lib/utils/include/utils/containers/map_keys_and_values.h index 651ffb2aeb..70b7e17103 100644 --- a/lib/utils/include/utils/containers/map_keys_and_values.h +++ b/lib/utils/include/utils/containers/map_keys_and_values.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_AND_VALUES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_AND_VALUES_H -#include "utils/containers/keys.h" #include #include @@ -21,7 +20,7 @@ std::unordered_map map_keys_and_values( result.insert({fk(kv.first), fv(kv.second)}); } - ASSERT(keys(m).size() == keys(result).size(), + ASSERT(m.size() == result.size(), "keys passed to map_keys must be transformed into distinct keys"); return result; diff --git a/lib/utils/include/utils/containers/map_values.h b/lib/utils/include/utils/containers/map_values.h index bf377b2c93..575fff977e 100644 --- a/lib/utils/include/utils/containers/map_values.h +++ b/lib/utils/include/utils/containers/map_values.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -18,6 +19,18 @@ std::unordered_map map_values(std::unordered_map const &m, F &&f) { return result; } +template > +std::map map_values(std::map const &m, F &&f) { + std::map result; + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.second)}); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/require_all_of.h b/lib/utils/include/utils/containers/require_all_of.h new file mode 100644 index 0000000000..c085161659 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_of.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_OF_H + +#include +#include +#include + +namespace FlexFlow { + +template +void require_all_of(C const &c, F &&f) { + for (auto const &v : c) { + f(v); + } +} + +template +void require_all_of(std::unordered_map const &m, F &&f) { + for (auto const &[k, v] : m) { + f(k, v); + } +} + +template +void require_all_of(std::map const &m, F &&f) { + for (auto const &[k, v] : m) { + f(k, v); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_only_key.h b/lib/utils/include/utils/containers/require_only_key.h index c63ff4d440..ef142921ec 100644 --- a/lib/utils/include/utils/containers/require_only_key.h +++ b/lib/utils/include/utils/containers/require_only_key.h @@ -4,6 +4,7 @@ #include "utils/containers/contains_key.h" #include #include +#include namespace FlexFlow { @@ -15,6 +16,14 @@ V require_only_key(std::unordered_map const &m, K const &k) { return m.at(k); } +template +V require_only_key(std::map const &m, K const &k) { + ASSERT(m.size() == 1); + ASSERT(contains_key(m, k)); + + return m.at(k); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/require_same.h b/lib/utils/include/utils/containers/require_same.h index 2f3439db32..2f6251064c 100644 --- a/lib/utils/include/utils/containers/require_same.h +++ b/lib/utils/include/utils/containers/require_same.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_SAME_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_SAME_H -#include "utils/exception.h" +#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index 14ef782690..bb34b2b5a5 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_H #include "utils/containers/vector_transform.h" -#include "utils/required_core.h" #include #include #include #include #include #include +#include +#include +#include namespace FlexFlow { @@ -17,12 +19,6 @@ std::vector transform(std::vector const &v, F const &f) { return vector_transform(v, f); } -template -auto transform(req const &c, F const &f) - -> decltype(transform(std::declval(), std::declval())) { - return transform(static_cast(c), f); -} - template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; @@ -86,8 +82,8 @@ template ::first_type, typename V2 = typename std::invoke_result_t::second_type> -std::unordered_map transform(std::map const &m, F const &f) { - std::unordered_map result; +std::map transform(std::map const &m, F const &f) { + std::map result; for (auto const &[k, v] : m) { result.insert(f(k, v)); } diff --git a/lib/utils/include/utils/containers/unordered_items.h b/lib/utils/include/utils/containers/unordered_items.h new file mode 100644 index 0000000000..1bd8da1498 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_items.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_ITEMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_ITEMS_H + +#include +#include "utils/hash/pair.h" + +namespace FlexFlow { + +template +std::unordered_set> + unordered_items(C const &c) { + return {c.begin(), c.end()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_keys.h b/lib/utils/include/utils/containers/unordered_keys.h new file mode 100644 index 0000000000..e4b74a6f5e --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_keys.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_KEYS_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set unordered_keys(std::unordered_map const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +template +std::unordered_set unordered_keys(std::map const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_map_from_map.h b/lib/utils/include/utils/containers/unordered_map_from_map.h new file mode 100644 index 0000000000..4410e5a767 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_map_from_map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_MAP_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_map unordered_map_from_map(std::map const &u) { + std::unordered_map result{u.cbegin(), u.cend()}; + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_values_strict.h b/lib/utils/include/utils/containers/zip_values_strict.h index 60a7985bc5..1a3ce95eb1 100644 --- a/lib/utils/include/utils/containers/zip_values_strict.h +++ b/lib/utils/include/utils/containers/zip_values_strict.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_H -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/require_same.h" #include #include @@ -14,9 +14,9 @@ std::unordered_map> zip_values_strict(std::unordered_map const &m1, std::unordered_map const &m2) { - ASSERT(keys(m1) == keys(m2)); + ASSERT(unordered_keys(m1) == unordered_keys(m2)); - return generate_map(require_same(keys(m1), keys(m2)), [&](K const &k) { + return generate_unordered_map(require_same(unordered_keys(m1), unordered_keys(m2)), [&](K const &k) { return std::pair{ m1.at(k), m2.at(k), diff --git a/lib/utils/include/utils/containers/zip_values_strict_with.h b/lib/utils/include/utils/containers/zip_values_strict_with.h index 3b0530db8a..5fc4bb7f5b 100644 --- a/lib/utils/include/utils/containers/zip_values_strict_with.h +++ b/lib/utils/include/utils/containers/zip_values_strict_with.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_WITH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_WITH_H -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/require_same.h" #include #include @@ -19,9 +19,9 @@ std::unordered_map std::unordered_map const &m2, F &&f) { - ASSERT(keys(m1) == keys(m2)); + ASSERT(unordered_keys(m1) == unordered_keys(m2)); - return generate_map(require_same(keys(m1), keys(m2)), + return generate_unordered_map(require_same(unordered_keys(m1), unordered_keys(m2)), [&](K const &k) -> Out { return f(m1.at(k), m2.at(k)); }); } diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 9225040d4d..5b0d41a9cc 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -22,10 +22,8 @@ struct formatter< CHECK_FMTABLE(K); CHECK_FMTABLE(V); - std::vector> items = ::FlexFlow::sorted(m); - std::string result = ::FlexFlow::join_strings( - items.cbegin(), items.cend(), ", ", [](std::pair const &p) { + m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); diff --git a/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h index 418346bb36..4bb8373666 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/set_union.h" #include "utils/containers/values.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" @@ -26,7 +26,7 @@ struct UnorderedSetKwargDataflowGraph final Node new_node = this->node_source.new_node(); std::unordered_map> outputs = - generate_map( + generate_unordered_map( output_slots, [&](SlotName const &output_slot) -> KwargDataflowOutput { KwargDataflowOutput output = diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index 159778bb6d..f05e6cb58a 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -4,8 +4,8 @@ #include "utils/containers/count.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/filter.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_keys.h" #include "utils/containers/transform.h" #include "utils/containers/without_nullopts.h" @@ -23,6 +23,7 @@ #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -81,7 +82,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final } std::unordered_set query_nodes(NodeQuery const &q) const override { - return filter(keys(this->nodes), + return filter(unordered_keys(this->nodes), [&](Node const &n) { return includes(q.nodes, n); }); } @@ -95,7 +96,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { return without_nullopts(transform( - keys(this->values), + unordered_keys(this->values), [&](OpenDataflowValue const &v) -> std::optional { if (!v.has()) { return std::nullopt; @@ -128,12 +129,12 @@ struct UnorderedSetLabelledOpenDataflowGraph final std::unordered_set outputs = get_all_dataflow_outputs(view); std::unordered_set edges = get_edges(view); std::unordered_map labelled_outputs = - generate_map(outputs, + generate_unordered_map(outputs, [&](DataflowOutput const &o) { return view.at(o); }); this->inputs.clear(); this->nodes = - generate_map(nodes, [&](Node const &n) { return view.at(n); }); + generate_unordered_map(nodes, [&](Node const &n) { return view.at(n); }); this->edges = transform( edges, [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); this->values = map_keys(labelled_outputs, [](DataflowOutput const &o) { @@ -145,14 +146,14 @@ struct UnorderedSetLabelledOpenDataflowGraph final LabelledOpenDataflowGraphView const &view) override { - std::unordered_map nodes = generate_map( + std::unordered_map nodes = generate_unordered_map( get_nodes(view), [&](Node const &n) { return view.at(n); }); std::unordered_set edges = get_edges(view); std::unordered_set inputs = ::FlexFlow::get_open_dataflow_graph_inputs(view); std::unordered_map values = - generate_map(get_open_dataflow_values(view), + generate_unordered_map(get_open_dataflow_values(view), [&](OpenDataflowValue const &v) { return view.at(v); }); this->inputs = inputs; diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h index 2b20b94c96..ab60e3d364 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h @@ -4,7 +4,7 @@ #include "utils/containers/contains_key.h" #include "utils/containers/enumerate.h" #include "utils/containers/extend.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/map_values.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" @@ -17,6 +17,7 @@ #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" #include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" #include "utils/overload.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -68,8 +69,8 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final } std::unordered_map> outputs = - generate_map( - keys(output_labels), + generate_unordered_map( + unordered_keys(output_labels), [&](SlotName const &output_slot) -> KwargDataflowOutput { ValueLabel value_label = output_labels.at(output_slot); @@ -106,7 +107,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final } std::unordered_set query_nodes(NodeQuery const &q) const override { - return filter(keys(this->nodes), + return filter(unordered_keys(this->nodes), [&](Node const &n) { return includes(q.nodes, n); }); } @@ -122,7 +123,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final std::unordered_set> query_outputs( KwargDataflowOutputQuery const &q) const override { - return filter(keys(this->outputs), + return filter(unordered_keys(this->outputs), [&](KwargDataflowOutput const &output) { return kwarg_dataflow_output_query_includes(q, output); }); @@ -130,7 +131,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final std::unordered_set> get_inputs() const override { - return keys(this->graph_inputs); + return unordered_keys(this->graph_inputs); } NodeLabel at(Node const &n) const override { @@ -159,7 +160,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final this->graph_inputs.clear(); this->nodes = - generate_map(view_nodes, [&](Node const &n) { return view.at(n); }); + generate_unordered_map(view_nodes, [&](Node const &n) { return view.at(n); }); this->edges = transform(view_edges, @@ -168,7 +169,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final return OpenKwargDataflowEdge{e}; }); this->outputs = - generate_map(view_outputs, [&](KwargDataflowOutput const &o) { + generate_unordered_map(view_outputs, [&](KwargDataflowOutput const &o) { return view.at(o); }); } @@ -186,16 +187,16 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final std::unordered_set> view_outputs = get_all_kwarg_dataflow_outputs(view); - this->graph_inputs = generate_map( + this->graph_inputs = generate_unordered_map( view_inputs, [&](KwargDataflowGraphInput const &i) { return view.at(OpenKwargDataflowValue{i}); }); this->nodes = - generate_map(view_nodes, [&](Node const &n) { return view.at(n); }); + generate_unordered_map(view_nodes, [&](Node const &n) { return view.at(n); }); this->edges = view_edges; this->outputs = - generate_map(view_outputs, [&](KwargDataflowOutput const &o) { + generate_unordered_map(view_outputs, [&](KwargDataflowOutput const &o) { return view.at(OpenKwargDataflowValue{o}); }); } diff --git a/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h index 3c66b2c689..9746533c17 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_OPEN_KWARG_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_OPEN_KWARG_DATAFLOW_GRAPH_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" #include "utils/graph/node/node_source.h" #include "utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h" @@ -36,7 +36,7 @@ struct UnorderedSetOpenKwargDataflowGraph final } std::unordered_map> outputs = - generate_map( + generate_unordered_map( output_slots, [&](SlotName const &output_slot) -> KwargDataflowOutput { KwargDataflowOutput output = diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h index 32848f38a6..87ebab4c2d 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_SLOTS_FOR_NODE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_SLOTS_FOR_NODE_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h" namespace FlexFlow { @@ -10,7 +10,7 @@ template std::unordered_set get_incoming_slots_for_node(KwargDataflowGraphView const &g, Node n) { - return keys(get_incoming_kwarg_dataflow_edges_for_node(g, n)); + return unordered_keys(get_incoming_kwarg_dataflow_edges_for_node(g, n)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h index 372dfed1e8..6cf2b4b8a4 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_SLOTS_FOR_NODE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_SLOTS_FOR_NODE_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" namespace FlexFlow { @@ -10,7 +10,7 @@ template std::unordered_set get_outgoing_slots_for_node(KwargDataflowGraphView const &g, Node n) { - return keys(get_outgoing_kwarg_dataflow_outputs_for_node(g, n)); + return unordered_keys(get_outgoing_kwarg_dataflow_outputs_for_node(g, n)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h index 2115a03cda..502eeab73b 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h @@ -14,14 +14,14 @@ LabelledOpenDataflowGraphData get_graph_data( LabelledOpenDataflowGraphView const &g) { std::unordered_map node_data = - generate_map(get_nodes(g), [&](Node const &n) { return g.at(n); }); + generate_unordered_map(get_nodes(g), [&](Node const &n) { return g.at(n); }); std::unordered_set edges = get_edges(g); std::unordered_set inputs = g.get_inputs(); std::unordered_map value_data = - generate_map(get_open_dataflow_values(g), + generate_unordered_map(get_open_dataflow_values(g), [&](OpenDataflowValue const &v) { return g.at(v); }); return LabelledOpenDataflowGraphData{ diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h index 88132e0a79..580a35b3f7 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h @@ -30,10 +30,10 @@ LabelledOpenDataflowGraphView permute_input_ids( }; std::unordered_map node_labels = - generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + generate_unordered_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); std::unordered_map value_labels = - generate_map(get_open_dataflow_values(permuted), + generate_unordered_map(get_open_dataflow_values(permuted), [&](OpenDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h index 88950635d2..5119587654 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -37,12 +37,12 @@ LabelledOpenDataflowGraphView permute_node_ids( }; std::unordered_map node_labels = - generate_map(get_nodes(permuted), [&](Node const &new_node) { + generate_unordered_map(get_nodes(permuted), [&](Node const &new_node) { return g.at(old_node_from_new(new_node)); }); std::unordered_map value_labels = - generate_map(get_open_dataflow_values(permuted), + generate_unordered_map(get_open_dataflow_values(permuted), [&](OpenDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h index 92938d7142..fde90497e7 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" @@ -27,9 +27,9 @@ LabelledOpenDataflowGraphView rewrite_labels( }; std::unordered_map node_labels = - generate_map(get_nodes(g), get_new_node_label); + generate_unordered_map(get_nodes(g), get_new_node_label); std::unordered_map value_labels = - generate_map(get_open_dataflow_values(g), get_new_value_label); + generate_unordered_map(get_open_dataflow_values(g), get_new_value_label); return with_labelling(g, node_labels, value_labels); } diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h index d60c396274..e98b858019 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -28,12 +28,12 @@ LabelledOpenKwargDataflowGraphData{ - /*nodes=*/generate_map( + /*nodes=*/generate_unordered_map( get_nodes(g), [&](Node const &n) -> NodeLabel { return g.at(n); }), /*edges=*/get_all_open_kwarg_dataflow_edges(g), /*inputs=*/get_all_kwarg_dataflow_graph_inputs(g), /*outputs=*/ - generate_map( + generate_unordered_map( get_all_open_kwarg_dataflow_values(g), [&](OpenKwargDataflowValue const &v) -> ValueLabel { return g.at(v); }), diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h index d06b96e37f..b6a4366fc2 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h @@ -21,12 +21,12 @@ OpenKwargDataflowGraphData SlotName> const &labelled_data) { OpenKwargDataflowGraphData result = OpenKwargDataflowGraphData{ - /*nodes=*/keys(labelled_data.node_data), + /*nodes=*/unordered_keys(labelled_data.node_data), /*edges=*/labelled_data.edges, /*inputs=*/labelled_data.inputs, /*outputs=*/ filtrans( - keys(labelled_data.value_data), + unordered_keys(labelled_data.value_data), [](OpenKwargDataflowValue const &v) { return v.try_require_internal(); }), diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h index 223c2e7673..6e1c8bdc57 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -57,11 +57,11 @@ LabelledOpenKwargDataflowGraphView node_labels = - generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + generate_unordered_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); std::unordered_map, ValueLabel> - value_labels = generate_map( + value_labels = generate_unordered_map( get_all_open_kwarg_dataflow_values(permuted), [&](OpenKwargDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h index 06728949df..d89c935d52 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms/new_node.dtg.h" @@ -56,13 +56,13 @@ LabelledOpenKwargDataflowGraphView node_labels = - generate_map(get_nodes(permuted), [&](Node const &new_node) { + generate_unordered_map(get_nodes(permuted), [&](Node const &new_node) { return g.at(old_node_from_new(new_node)); }); std::unordered_map, ValueLabel> - value_labels = generate_map( + value_labels = generate_unordered_map( get_all_open_kwarg_dataflow_values(permuted), [&](OpenKwargDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h index a632cd7b64..d5d1435462 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -39,10 +39,10 @@ LabelledOpenKwargDataflowGraphView NewValueLabel { return f(v, g.at(v)); }; std::unordered_map node_labels = - generate_map(get_nodes(g), get_new_node_label); + generate_unordered_map(get_nodes(g), get_new_node_label); std::unordered_map, NewValueLabel> - value_labels = generate_map(get_all_open_kwarg_dataflow_values(g), + value_labels = generate_unordered_map(get_all_open_kwarg_dataflow_values(g), get_new_value_label); return open_kwarg_dataflow_graph_view_with_labelling( g, node_labels, value_labels); diff --git a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h index f8e9f52cb5..e658c04060 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h +++ b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLLEL_GET_ANCESTORS_H #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml index b4b975bb4e..a262de1a6b 100644 --- a/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml @@ -4,6 +4,7 @@ type = "struct" features = [ "eq", "hash", + "ord", "fmt", ] @@ -16,19 +17,19 @@ post_includes = [ ] includes = [ - "", + "", "", "utils/graph/node/node.dtg.h", ] src_includes = [ "utils/fmt/variant.h", - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", + "utils/fmt/multiset.h", + "utils/hash/multiset.h", ] [[fields]] name = "children" -type = "std::unordered_multiset>" +type = "std::multiset>" indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml index 008e58dc3f..8f0de7082d 100644 --- a/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "ord", ] fwd_decls = [ @@ -23,6 +24,7 @@ includes = [ src_includes = [ "utils/fmt/variant.h", + "utils/ord/vector.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml index c82e771385..d6289f1d75 100644 --- a/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml @@ -3,6 +3,7 @@ name = "NonNormalSPDecomposition" type = "variant" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml index a3315d506b..eb907d4d43 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml @@ -3,6 +3,7 @@ name = "ParallelSplit" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] @@ -16,18 +17,18 @@ post_includes = [ ] includes = [ - "", + "", "", "utils/graph/node/node.dtg.h", ] src_includes = [ "utils/fmt/variant.h", - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", + "utils/fmt/multiset.h", + "utils/hash/multiset.h", ] [[fields]] name = "children" -type = "std::unordered_multiset>" +type = "std::multiset>" indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml index 4635bdd877..b47a8eabaa 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml @@ -3,6 +3,7 @@ name = "SeriesParallelDecomposition" type = "variant" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml index e37762a059..cb5753627d 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml @@ -3,6 +3,7 @@ name = "SeriesSplit" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] @@ -23,6 +24,7 @@ includes = [ src_includes = [ "utils/fmt/variant.h", + "utils/ord/vector.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h index 9d4f65f005..8597e11c22 100644 --- a/lib/utils/include/utils/json/check_is_jsonable.h +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -7,9 +7,9 @@ namespace FlexFlow { #define CHECK_IS_JSONABLE(...) \ - static_assert(is_json_serializable<__VA_ARGS__>::value, \ + static_assert(::FlexFlow::is_json_serializable<__VA_ARGS__>::value, \ #__VA_ARGS__ " should be json serializeable"); \ - static_assert(is_json_deserializable<__VA_ARGS__>::value, \ + static_assert(::FlexFlow::is_json_deserializable<__VA_ARGS__>::value, \ #__VA_ARGS__ " should be json deserializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index 2d078eb304..a501a0672c 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H -#include "utils/containers/keys.h" #include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" @@ -20,6 +19,8 @@ #include #include #include +#include "utils/containers/set_of.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -88,7 +89,7 @@ struct ManyToOne { } std::unordered_set left_values() const { - return keys(this->m_l_to_r); + return unordered_keys(this->m_l_to_r); } std::unordered_set> left_groups() const { @@ -96,7 +97,7 @@ struct ManyToOne { } std::unordered_set right_values() const { - return keys(this->m_r_to_l); + return unordered_keys(this->m_r_to_l); } std::unordered_map const &l_to_r() const { @@ -141,6 +142,22 @@ std::ostream &operator<<(std::ostream &s, ManyToOne const &m) { return (s << fmt::to_string(m)); } +template +std::unordered_set> + unstructured_relation_from_many_to_one(ManyToOne const &many_to_one) { + return unordered_set_of(many_to_one.l_to_r()); +} + +template +ManyToOne many_to_one_from_unstructured_relation( + std::unordered_set> const &relation) { + ManyToOne result; + for (auto const &lr : relation) { + result.insert(lr); + } + return result; +} + } // namespace FlexFlow namespace nlohmann { @@ -151,14 +168,16 @@ struct adl_serializer<::FlexFlow::ManyToOne> { CHECK_IS_JSON_DESERIALIZABLE(L); CHECK_IS_JSON_DESERIALIZABLE(R); - NOT_IMPLEMENTED(); + std::unordered_set> s = j; + + return ::FlexFlow::many_to_one_from_unstructured_relation(s); } static void to_json(json &j, ::FlexFlow::ManyToOne const &m) { CHECK_IS_JSON_SERIALIZABLE(L); CHECK_IS_JSON_SERIALIZABLE(R); - NOT_IMPLEMENTED(); + j = ::FlexFlow::set_of(::FlexFlow::unstructured_relation_from_many_to_one(m)); } }; diff --git a/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h b/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h deleted file mode 100644 index 171c6c15d6..0000000000 --- a/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_UNSTRUCTURED_RELATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_UNSTRUCTURED_RELATION_H - -#include "utils/many_to_one/many_to_one.h" - -namespace FlexFlow { - -template -ManyToOne many_to_one_from_unstructured_relation( - std::unordered_set> const &relation) { - ManyToOne result; - for (auto const &lr : relation) { - result.insert(lr); - } - return result; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h b/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h deleted file mode 100644 index 676c5efa5d..0000000000 --- a/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_UNSTRUCTURED_RELATION_FROM_MANY_TO_ONE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_UNSTRUCTURED_RELATION_FROM_MANY_TO_ONE_H - -#include "utils/containers/unordered_set_of.h" -#include "utils/many_to_one/many_to_one.h" - -namespace FlexFlow { - -template -std::unordered_set> - unstructured_relation_from_many_to_one(ManyToOne const &many_to_one) { - return unordered_set_of(many_to_one.l_to_r()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/nonempty_set/nonempty_set.h b/lib/utils/include/utils/nonempty_set/nonempty_set.h new file mode 100644 index 0000000000..93da743592 --- /dev/null +++ b/lib/utils/include/utils/nonempty_set/nonempty_set.h @@ -0,0 +1,136 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONEMPTY_SET_NONEMPTY_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONEMPTY_SET_NONEMPTY_SET_H + +#include +#include +#include "utils/hash-utils.h" +#include "utils/hash/set.h" +#include "utils/fmt/set.h" +#include "utils/positive_int/positive_int.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +template +struct nonempty_set { +public: + nonempty_set() = delete; + + nonempty_set(std::initializer_list const &vs) : raw(vs) { + ASSERT(this->raw.size() > 0); + } + + explicit nonempty_set(std::set const &s) : raw(s) { + ASSERT(this->raw.size() > 0); + } + + bool operator==(nonempty_set const &other) const { + return this->unwrap_as_set() == other.unwrap_as_set(); + } + + bool operator!=(nonempty_set const &other) const { + return this->unwrap_as_set() != other.unwrap_as_set(); + } + + bool operator<(nonempty_set const &other) const { + return this->unwrap_as_set() < other.unwrap_as_set(); + } + + bool operator<=(nonempty_set const &other) const { + return this->unwrap_as_set() <= other.unwrap_as_set(); + } + + bool operator>(nonempty_set const &other) const { + return this->unwrap_as_set() > other.unwrap_as_set(); + } + + bool operator>=(nonempty_set const &other) const { + return this->unwrap_as_set() >= other.unwrap_as_set(); + } + + bool operator==(std::set const &other) const { + return this->unwrap_as_set() == other; + } + + bool operator!=(std::set const &other) const { + return this->unwrap_as_set() != other; + } + + void insert(T const &t) { + this->raw.insert(t); + } + + size_t size() const { + return this->raw.size(); + }; + + positive_int num_elements() const { + return positive_int{this->raw.size()}; + }; + + std::set const &unwrap_as_set() const { + return this->raw; + } + + std::unordered_set unwrap_as_unordered_set() const { + return unordered_set_of(this->raw); + } + + using value_type = T; + + typename std::set::const_iterator begin() const { + return this->raw.cbegin(); + } + + typename std::set::const_iterator cbegin() const { + return this->raw.cbegin(); + } + + typename std::set::const_iterator end() const { + return this->raw.cend(); + } + + typename std::set::const_iterator cend() const { + return this->raw.cend(); + } + +private: + std::set raw; +}; + +template +bool operator==(std::set const &lhs, + nonempty_set const &rhs) { + return lhs == rhs.unwrap_as_set(); +} + +template +bool operator!=(std::set const &lhs, + nonempty_set const &rhs) { + return lhs != rhs.unwrap_as_set(); +} + +template +std::set format_as(nonempty_set const &s) { + return s.unwrap_as_set(); +} + +template +std::ostream &operator<<(std::ostream &s, nonempty_set const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::nonempty_set> { + size_t operator()(::FlexFlow::nonempty_set const &x) const { + return ::FlexFlow::get_std_hash(x.unwrap_as_set()); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 5492ff3f78..d57622b950 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -7,23 +7,23 @@ #include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/try_at.h" -#include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" #include "utils/exception.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" +#include "utils/fmt/map.h" +#include "utils/fmt/set.h" #include "utils/hash-utils.h" #include "utils/hash/tuple.h" -#include "utils/hash/unordered_map.h" -#include "utils/hash/unordered_set.h" +#include "utils/hash/map.h" +#include "utils/hash/set.h" #include "utils/json/check_is_json_deserializable.h" #include "utils/json/check_is_json_serializable.h" -#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/nonempty_set/nonempty_set.h" #include #include #include -#include -#include +#include +#include +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -54,6 +54,22 @@ struct OneToMany { return this->tie() != other.tie(); } + bool operator<(OneToMany const &other) const { + return this->tie() < other.tie(); + } + + bool operator<=(OneToMany const &other) const { + return this->tie() <= other.tie(); + } + + bool operator>(OneToMany const &other) const { + return this->tie() > other.tie(); + } + + bool operator>=(OneToMany const &other) const { + return this->tie() >= other.tie(); + } + void insert(std::pair const &p) { L l = p.first; R r = p.second; @@ -66,7 +82,7 @@ struct OneToMany { if (contains_key(this->m_l_to_r, l)) { this->m_l_to_r.at(l).insert(r); } else { - this->m_l_to_r.insert({l, nonempty_unordered_set{{r}}}); + this->m_l_to_r.insert({l, nonempty_set{{r}}}); } } else if (found_l.value() == l) { return; @@ -80,14 +96,14 @@ struct OneToMany { } } - std::unordered_set> relation() const { + std::set> relation() const { return transform(items(this->m_r_to_l), [](std::pair const &p) -> std::pair { return {p.second, p.first}; }); } - nonempty_unordered_set const &at_l(L const &l) const { + nonempty_set const &at_l(L const &l) const { return this->m_l_to_r.at(l); } @@ -95,23 +111,23 @@ struct OneToMany { return this->m_r_to_l.at(r); } - std::unordered_set left_values() const { + std::set left_values() const { return keys(this->m_l_to_r); } - std::unordered_set right_values() const { + std::set right_values() const { return keys(this->m_r_to_l); } - std::unordered_set> right_groups() const { - return unordered_set_of(values(this->m_l_to_r)); + std::set> right_groups() const { + return set_of(values(this->m_l_to_r)); } - std::unordered_map> const &l_to_r() const { + std::map> const &l_to_r() const { return this->m_l_to_r; } - std::unordered_map const &r_to_l() const { + std::map const &r_to_l() const { return this->m_r_to_l; } @@ -120,8 +136,8 @@ struct OneToMany { } private: - std::unordered_map> m_l_to_r; - std::unordered_map m_r_to_l; + std::map> m_l_to_r; + std::map m_r_to_l; private: std::tuple @@ -133,7 +149,7 @@ struct OneToMany { }; template -std::unordered_map> +std::map> format_as(OneToMany const &m) { return generate_map(m.left_values(), [&](L const &l) { return m.at_l(l); }); } @@ -143,6 +159,25 @@ std::ostream &operator<<(std::ostream &s, OneToMany const &m) { return (s << fmt::to_string(m)); } +template +std::unordered_set> + unstructured_relation_from_one_to_many(OneToMany const &one_to_many) { + return transform(unordered_set_of(one_to_many.r_to_l()), + [](std::pair const &rl) -> std::pair { + return std::pair{rl.second, rl.first}; + }); +} + +template +OneToMany one_to_many_from_unstructured_relation( + std::unordered_set> const &rel) { + OneToMany result; + for (auto const &lr : rel) { + result.insert(lr); + } + return result; +} + } // namespace FlexFlow namespace nlohmann { @@ -153,14 +188,16 @@ struct adl_serializer<::FlexFlow::OneToMany> { CHECK_IS_JSON_DESERIALIZABLE(L); CHECK_IS_JSON_DESERIALIZABLE(R); - NOT_IMPLEMENTED(); + std::unordered_set> s = j; + + return ::FlexFlow::one_to_many_from_unstructured_relation(s); } static void to_json(json &j, ::FlexFlow::OneToMany const &m) { CHECK_IS_JSON_SERIALIZABLE(L); CHECK_IS_JSON_SERIALIZABLE(R); - NOT_IMPLEMENTED(); + j = ::FlexFlow::set_of(::FlexFlow::unstructured_relation_from_one_to_many(m)); } }; diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h b/lib/utils/include/utils/one_to_many/one_to_many_filter_keys.h similarity index 50% rename from lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h rename to lib/utils/include/utils/one_to_many/one_to_many_filter_keys.h index 11c0a767d6..5f31fbb584 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h +++ b/lib/utils/include/utils/one_to_many/one_to_many_filter_keys.h @@ -1,16 +1,17 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_UNSTRUCTURED_RELATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_UNSTRUCTURED_RELATION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_KEYS_H #include "utils/one_to_many/one_to_many.h" namespace FlexFlow { -template -OneToMany one_to_many_from_unstructured_relation( - std::unordered_set> const &rel) { +template +OneToMany one_to_many_filter_keys(OneToMany const &m, F &&f) { OneToMany result; - for (auto const &lr : rel) { - result.insert(lr); + for (auto const &kv : unstructured_relation_from_one_to_many(m)) { + if (f(kv.first)) { + result.insert(kv); + } } return result; } diff --git a/lib/utils/include/utils/one_to_many/one_to_many_filter_values.h b/lib/utils/include/utils/one_to_many/one_to_many_filter_values.h new file mode 100644 index 0000000000..4694b06d65 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_filter_values.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_VALUES_H + +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany one_to_many_filter_values(OneToMany const &m, F &&f) { + OneToMany result; + for (auto const &kv : unstructured_relation_from_one_to_many(m)) { + if (f(kv.second)) { + result.insert(kv); + } + } + return result; +} + +} // namespace FlexFlow + + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h b/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h index a9afe98988..050f6ad7dd 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h +++ b/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h @@ -2,7 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_TRANSFORM_VALUES_H #include "utils/containers/transform.h" -#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" +#include "utils/one_to_many/one_to_many.h" + namespace FlexFlow { @@ -13,7 +14,8 @@ template one_to_many_transform_values(OneToMany const &input, F f) { return one_to_many_from_unstructured_relation(transform( - input.relation(), [&](std::pair const &p) -> std::pair { + unordered_set_of(input.relation()), + [&](std::pair const &p) -> std::pair { return {p.first, f(p.second)}; })); } diff --git a/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h b/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h deleted file mode 100644 index 02ed225610..0000000000 --- a/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_UNSTRUCTURED_RELATION_FROM_ONE_TO_MANY_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_UNSTRUCTURED_RELATION_FROM_ONE_TO_MANY_H - -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/one_to_many/one_to_many.h" - -namespace FlexFlow { - -template -std::unordered_set> - unstructured_relation_from_one_to_many(OneToMany const &one_to_many) { - return transform(unordered_set_of(one_to_many.r_to_l()), - [](std::pair const &rl) -> std::pair { - return std::pair{rl.second, rl.first}; - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/dim_coord.h b/lib/utils/include/utils/orthotope/dim_coord.h index 87a05a7315..b9a10d6750 100644 --- a/lib/utils/include/utils/orthotope/dim_coord.h +++ b/lib/utils/include/utils/orthotope/dim_coord.h @@ -3,10 +3,10 @@ #include "utils/containers/all_of.h" #include "utils/containers/contains_key.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/is_subseteq_of.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/containers/map_values.h" #include "utils/containers/product.h" @@ -29,7 +29,7 @@ namespace FlexFlow { template std::unordered_set get_coord_dims(DimCoord const &coord) { - return keys(coord.raw); + return unordered_keys(coord.raw); } template @@ -65,7 +65,7 @@ DimCoord lift_dim_coord(DimCoord const &coord, ASSERT(is_subseteq_of(get_coord_dims(coord), lifted_dims)); return DimCoord{ - generate_map(lifted_dims, + generate_unordered_map(lifted_dims, [&](T const &dim) { if (contains_key(coord.raw, dim)) { return coord.raw.at(dim); diff --git a/lib/utils/include/utils/orthotope/dim_domain.h b/lib/utils/include/utils/orthotope/dim_domain.h index c940745a78..6bd63faeae 100644 --- a/lib/utils/include/utils/orthotope/dim_domain.h +++ b/lib/utils/include/utils/orthotope/dim_domain.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_H #include "utils/containers/filter.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" @@ -27,7 +27,7 @@ nonnegative_int dim_domain_num_dims(DimDomain const &domain) { template std::unordered_set get_domain_dims(DimDomain const &domain) { - return keys(domain.dims); + return unordered_keys(domain.dims); } template diff --git a/lib/utils/include/utils/orthotope/dim_projection.dtg.toml b/lib/utils/include/utils/orthotope/dim_projection.dtg.toml index a530adac5d..9133c1a03f 100644 --- a/lib/utils/include/utils/orthotope/dim_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/dim_projection.dtg.toml @@ -3,6 +3,7 @@ name = "DimProjection" type = "variant" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/orthotope/down_projection.dtg.toml b/lib/utils/include/utils/orthotope/down_projection.dtg.toml index 9a642d2b9f..a83fd04e64 100644 --- a/lib/utils/include/utils/orthotope/down_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/down_projection.dtg.toml @@ -3,6 +3,7 @@ name = "DownProjection" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/orthotope/down_projection.h b/lib/utils/include/utils/orthotope/down_projection.h index f46a0f16c8..8fb1487ad3 100644 --- a/lib/utils/include/utils/orthotope/down_projection.h +++ b/lib/utils/include/utils/orthotope/down_projection.h @@ -48,7 +48,7 @@ DimCoord compute_down_projection(DownProjection const &projection, output_dims_of_down_projection(projection); return DimCoord{ - generate_map( + generate_unordered_map( output_dims, [&](R const &output_dim) { std::unordered_set src_dims = diff --git a/lib/utils/include/utils/orthotope/eq_projection.dtg.toml b/lib/utils/include/utils/orthotope/eq_projection.dtg.toml index 972952f907..456a44b2d3 100644 --- a/lib/utils/include/utils/orthotope/eq_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/eq_projection.dtg.toml @@ -3,6 +3,7 @@ name = "EqProjection" type = "struct" features = [ "eq", + "ord", "hash", "fmt", "rapidcheck", diff --git a/lib/utils/include/utils/orthotope/minimal_dim_domain.h b/lib/utils/include/utils/orthotope/minimal_dim_domain.h index 3934e2af62..c9d1214278 100644 --- a/lib/utils/include/utils/orthotope/minimal_dim_domain.h +++ b/lib/utils/include/utils/orthotope/minimal_dim_domain.h @@ -4,8 +4,7 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filtermap_values.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" @@ -16,6 +15,7 @@ #include "utils/orthotope/dim_ordering.dtg.h" #include "utils/orthotope/minimal_dim_domain.dtg.h" #include "utils/orthotope/minimal_orthotope.dtg.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -70,14 +70,14 @@ DimDomain dim_domain_from_minimal_dim_domain( map_values( minimal_dim_domain.dims, [](int_ge_two x) { return x.positive_int_from_int_ge_two(); }), - generate_map(trivial_dims, [](T const &) { return 1_p; })), + generate_unordered_map(trivial_dims, [](T const &) { return 1_p; })), }; } template std::unordered_set get_minimal_domain_dims(MinimalDimDomain const &domain) { - return keys(domain.dims); + return unordered_keys(domain.dims); } template diff --git a/lib/utils/include/utils/orthotope/up_projection.dtg.toml b/lib/utils/include/utils/orthotope/up_projection.dtg.toml index c99e6eec93..7f69c9dd0e 100644 --- a/lib/utils/include/utils/orthotope/up_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/up_projection.dtg.toml @@ -3,6 +3,7 @@ name = "UpProjection" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h index e485419fbb..1e241108e2 100644 --- a/lib/utils/include/utils/orthotope/up_projection.h +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -24,13 +24,13 @@ UpProjection make_empty_up_projection() { template std::unordered_set input_dims_of_up_projection(UpProjection const &projection) { - return projection.dim_mapping.left_values(); + return unordered_set_of(projection.dim_mapping.left_values()); } template std::unordered_set output_dims_of_up_projection(UpProjection const &projection) { - return projection.dim_mapping.right_values(); + return unordered_set_of(projection.dim_mapping.right_values()); } template diff --git a/lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc b/lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc new file mode 100644 index 0000000000..b83da321f7 --- /dev/null +++ b/lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc @@ -0,0 +1,7 @@ +#include "utils/archetypes/jsonable_ordered_value_type.h" + +namespace FlexFlow { + +template struct jsonable_ordered_value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filter_keys.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filter_keys.cc similarity index 58% rename from lib/utils/src/utils/bidict/algorithms/filter_keys.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filter_keys.cc index 57ef4e873d..5c84ec85b1 100644 --- a/lib/utils/src/utils/bidict/algorithms/filter_keys.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filter_keys.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/bidict/algorithms/bidict_filter_keys.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -7,6 +7,6 @@ using K = value_type<0>; using V = value_type<1>; using F = std::function; -template bidict filter_keys(bidict const &, F &&); +template bidict bidict_filter_keys(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filter_values.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filter_values.cc similarity index 57% rename from lib/utils/src/utils/bidict/algorithms/filter_values.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filter_values.cc index 4cf58037ee..7adf808f85 100644 --- a/lib/utils/src/utils/bidict/algorithms/filter_values.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filter_values.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filter_values.h" +#include "utils/bidict/algorithms/bidict_filter_values.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -7,6 +7,6 @@ using K = value_type<0>; using V = value_type<1>; using F = std::function; -template bidict filter_values(bidict const &, F &&); +template bidict bidict_filter_values(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_keys.cc similarity index 61% rename from lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filtrans_keys.cc index 1e506b8a51..954558a66a 100644 --- a/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_keys.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filtrans_keys.h" +#include "utils/bidict/algorithms/bidict_filtrans_keys.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -8,6 +8,6 @@ using V = value_type<1>; using K2 = value_type<2>; using F = std::function(K)>; -template bidict filtrans_keys(bidict const &, F &&); +template bidict bidict_filtrans_keys(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_values.cc similarity index 61% rename from lib/utils/src/utils/bidict/algorithms/filtrans_values.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filtrans_values.cc index 8d1352196c..303d40575a 100644 --- a/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_values.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filtrans_values.h" +#include "utils/bidict/algorithms/bidict_filtrans_values.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -8,6 +8,6 @@ using V = value_type<1>; using V2 = value_type<2>; using F = std::function(V)>; -template bidict filtrans_values(bidict const &, F &&); +template bidict bidict_filtrans_values(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc b/lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc new file mode 100644 index 0000000000..425c1d099c --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc @@ -0,0 +1,11 @@ +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +std::unordered_set> bidict_unordered_set_of(bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc b/lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc deleted file mode 100644 index a0bfa1525e..0000000000 --- a/lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "utils/bidict/algorithms/unordered_set_of.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using K = value_type<0>; -using V = value_type<1>; - -std::unordered_set> unordered_set_of(bidict const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse.cc b/lib/utils/src/utils/cli/cli_parse.cc index 36d5837f9c..8f5f81324c 100644 --- a/lib/utils/src/utils/cli/cli_parse.cc +++ b/lib/utils/src/utils/cli/cli_parse.cc @@ -2,7 +2,7 @@ #include "utils/cli/cli_spec.h" #include "utils/containers/contains.h" #include "utils/containers/enumerate.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" namespace FlexFlow { @@ -27,7 +27,7 @@ tl::expected cli_parse_flag(CLISpec const &cli, tl::expected cli_parse(CLISpec const &cli, std::vector const &args) { CLIParseResult result = CLIParseResult{ - generate_map(cli_get_flag_keys(cli), + generate_unordered_map(cli_get_flag_keys(cli), [](CLIFlagKey const &) { return false; }), {}, }; diff --git a/lib/utils/src/utils/containers/filter.cc b/lib/utils/src/utils/containers/filter.cc index dc11d0dffa..4931d97704 100644 --- a/lib/utils/src/utils/containers/filter.cc +++ b/lib/utils/src/utils/containers/filter.cc @@ -1 +1,45 @@ #include "utils/containers/filter.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using OVT0 = ordered_value_type<0>; + +template + std::vector filter(std::vector const &, + std::function const &); + +template + std::unordered_set filter(std::unordered_set const &, + std::function const &); + +template + std::unordered_map + filter(std::unordered_map const &, + std::function const &)> const &); + +template + std::set filter( + std::set const &, + std::function const &); + +template + std::map + filter(std::map const &, + std::function const &)> const &); + +template + std::multiset + filter(std::multiset const &, + std::function const &); + +template + std::unordered_multiset + filter(std::unordered_multiset const &, + std::function const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/generate_map.cc b/lib/utils/src/utils/containers/generate_map.cc index 54bbe13dc9..57caa3d062 100644 --- a/lib/utils/src/utils/containers/generate_map.cc +++ b/lib/utils/src/utils/containers/generate_map.cc @@ -1 +1,13 @@ #include "utils/containers/generate_map.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; +using F = std::function; + +template std::map generate_map(std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/generate_unordered_map.cc b/lib/utils/src/utils/containers/generate_unordered_map.cc new file mode 100644 index 0000000000..2287e45632 --- /dev/null +++ b/lib/utils/src/utils/containers/generate_unordered_map.cc @@ -0,0 +1,13 @@ +#include "utils/containers/generate_unordered_map.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +template std::unordered_map generate_unordered_map(std::unordered_set const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/group_by.cc b/lib/utils/src/utils/containers/group_by.cc index efd8c2032b..a41ab4dc62 100644 --- a/lib/utils/src/utils/containers/group_by.cc +++ b/lib/utils/src/utils/containers/group_by.cc @@ -4,8 +4,8 @@ namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; using F = std::function; template OneToMany group_by(std::unordered_set const &, F &&); diff --git a/lib/utils/src/utils/containers/is_submapeq_of.cc b/lib/utils/src/utils/containers/is_submapeq_of.cc index 567d94fac5..f8fd627b3d 100644 --- a/lib/utils/src/utils/containers/is_submapeq_of.cc +++ b/lib/utils/src/utils/containers/is_submapeq_of.cc @@ -1 +1,12 @@ #include "utils/containers/is_submapeq_of.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +bool is_submapeq_of(std::unordered_map const &, std::unordered_map const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/items.cc b/lib/utils/src/utils/containers/items.cc index 3b1e80452a..193cec9ca8 100644 --- a/lib/utils/src/utils/containers/items.cc +++ b/lib/utils/src/utils/containers/items.cc @@ -1 +1,14 @@ #include "utils/containers/items.h" +#include "utils/archetypes/ordered_value_type.h" +#include +#include + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; + +template std::set> items(std::unordered_map const &); +template std::set> items(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/keys.cc b/lib/utils/src/utils/containers/keys.cc index 6c6abadd56..96db33f4c7 100644 --- a/lib/utils/src/utils/containers/keys.cc +++ b/lib/utils/src/utils/containers/keys.cc @@ -1 +1,13 @@ #include "utils/containers/keys.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template std::set keys(std::unordered_map const &); +template std::set keys(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_from_unordered.cc b/lib/utils/src/utils/containers/map_from_unordered.cc new file mode 100644 index 0000000000..11558af765 --- /dev/null +++ b/lib/utils/src/utils/containers/map_from_unordered.cc @@ -0,0 +1,13 @@ +#include "utils/containers/map_from_unordered.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template + std::map map_from_unordered(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_values.cc b/lib/utils/src/utils/containers/map_values.cc index e26035e8b1..e850ecf31e 100644 --- a/lib/utils/src/utils/containers/map_values.cc +++ b/lib/utils/src/utils/containers/map_values.cc @@ -1,5 +1,6 @@ #include "utils/containers/map_values.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { @@ -11,4 +12,11 @@ using F = std::function; template std::unordered_map map_values(std::unordered_map const &, F &&); +using KO = ordered_value_type<0>; +using VO = value_type<1>; +using VO2 = value_type<2>; +using FO = std::function; + +template std::map map_values(std::map const &, FO &&); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/require_all_of.cc b/lib/utils/src/utils/containers/require_all_of.cc new file mode 100644 index 0000000000..7fbf48c54f --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_of.cc @@ -0,0 +1,34 @@ +#include "utils/containers/require_all_of.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include +#include + +namespace FlexFlow { + +using T1 = value_type<0>; +using F1 = std::function; + +template void require_all_of(std::vector const &, F1 &&); +template void require_all_of(std::unordered_set const &, F1 &&); +template void require_all_of(std::unordered_multiset const &, F1 &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template void require_all_of(std::set const &, F2 &&); +template void require_all_of(std::multiset const &, F2 &&); + +using K3 = value_type<0>; +using V3 = value_type<1>; +using F3 = std::function; + +template void require_all_of(std::unordered_map const &, F3 &&); + +using K4 = ordered_value_type<0>; +using V4 = ordered_value_type<1>; +using F4 = std::function; + +template void require_all_of(std::map const &, F4 &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/require_only_key.cc b/lib/utils/src/utils/containers/require_only_key.cc index ac8c201303..26ec81528a 100644 --- a/lib/utils/src/utils/containers/require_only_key.cc +++ b/lib/utils/src/utils/containers/require_only_key.cc @@ -1,5 +1,6 @@ #include "utils/containers/require_only_key.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { @@ -8,4 +9,8 @@ using V = value_type<1>; template V require_only_key(std::unordered_map const &, K const &); +using K2 = ordered_value_type<0>; + +template V require_only_key(std::map const &, K2 const &); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform.cc b/lib/utils/src/utils/containers/transform.cc index 7cd5a56ed4..55255cbc1e 100644 --- a/lib/utils/src/utils/containers/transform.cc +++ b/lib/utils/src/utils/containers/transform.cc @@ -1 +1,47 @@ #include "utils/containers/transform.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using In = value_type<0>; +using Out = value_type<1>; +using F = std::function; + +template std::vector transform(std::vector const &, F const &); + +template std::unordered_set transform(std::unordered_set const &, F const &); + +template std::unordered_multiset transform(std::unordered_multiset const &, F const &); + +using In2 = ordered_value_type<0>; +using Out2 = ordered_value_type<1>; +using F2 = std::function; + +template std::set transform(std::set const &, F2 const &); + +template std::multiset transform(std::multiset const &v, F2 const &f); + +using F3 = std::function; + +template std::string transform(std::string const &, F3 const &); + +using K = value_type<3>; +using V = value_type<4>; +using K2 = value_type<5>; +using V2 = value_type<6>; + +template std::unordered_map transform(std::unordered_map const &, + std::function(K const &, V const &)> const &); + +using K3 = ordered_value_type<3>; +using V3 = value_type<4>; +using K4 = ordered_value_type<5>; +using V4 = value_type<6>; + +template std::map transform(std::map const &, + std::function(K3 const &, V3 const &)> const &); + +template std::optional transform(std::optional const &o, + std::function const &); +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/unordered_items.cc b/lib/utils/src/utils/containers/unordered_items.cc new file mode 100644 index 0000000000..9b58cfd18e --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_items.cc @@ -0,0 +1,15 @@ +#include "utils/containers/unordered_items.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include +#include + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template std::unordered_set> unordered_items(std::unordered_map const &); +template std::unordered_set> unordered_items(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/unordered_keys.cc b/lib/utils/src/utils/containers/unordered_keys.cc new file mode 100644 index 0000000000..e850b5f460 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_keys.cc @@ -0,0 +1,13 @@ +#include "utils/containers/unordered_keys.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template std::unordered_set unordered_keys(std::unordered_map const &); +std::unordered_set unordered_keys(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/unordered_map_from_map.cc b/lib/utils/src/utils/containers/unordered_map_from_map.cc new file mode 100644 index 0000000000..a0ffa034b5 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_map_from_map.cc @@ -0,0 +1,12 @@ +#include "utils/containers/unordered_map_from_map.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<0>; + +template std::unordered_map unordered_map_from_map(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc index f617c52593..c9889a49e7 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc @@ -1,5 +1,4 @@ #include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" -#include "utils/containers/generate_map.h" #include "utils/containers/map_keys.h" #include "utils/dot/dot_file.h" #include "utils/dot/dot_html_from_json.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index 2da1b208f4..5a9ca06c5e 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_dominators_map.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" @@ -25,7 +25,7 @@ std::unordered_map> } std::unordered_map> result = - generate_map(get_nodes(g), [&](Node const &) { return get_nodes(g); }); + generate_unordered_map(get_nodes(g), [&](Node const &) { return get_nodes(g); }); while (!queue.empty()) { Node n = queue.front(); queue.pop(); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc index 34cc7fcc6f..42a9830837 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc @@ -1,7 +1,7 @@ #include "utils/graph/digraph/algorithms/get_imm_dominators_map.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_element_counts.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" @@ -27,14 +27,14 @@ std::unordered_map> })); std::unordered_map dominator_counts = get_element_counts(recursive_dominator_list); - std::unordered_set imm_dominators = keys( + std::unordered_set imm_dominators = unordered_keys( filter_values(dominator_counts, [](int count) { return count <= 1; })); - assert(imm_dominators.size() <= 1); + ASSERT(imm_dominators.size() <= 1); return maybe_get_only(imm_dominators); }; - return generate_map(get_nodes(g), get_imm_dominator); + return generate_unordered_map(get_nodes(g), get_imm_dominator); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc index 39523f2ec1..899ed385c7 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominator.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_one_of.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" @@ -33,7 +33,7 @@ std::optional Node contracted_node = get_one_of(nodes); std::unordered_map contraction = - generate_map(nodes, [&](Node const &) { return contracted_node; }); + generate_unordered_map(nodes, [&](Node const &) { return contracted_node; }); return get_imm_post_dominator(apply_contraction(g, contraction), contracted_node); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc index db09dd07d6..e6c9d5e557 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc @@ -2,6 +2,8 @@ #include "utils/containers/group_by.h" #include "utils/containers/map_values.h" #include "utils/containers/set_of.h" +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -16,14 +18,18 @@ std::unordered_set get_incoming_edges(DiGraphView const &g, std::unordered_map> get_incoming_edges(DiGraphView const &g, std::unordered_set const &ns) { - std::unordered_map> result = - map_values(group_by(g.query_edges(DirectedEdgeQuery{ - query_set::matchall(), - query_set::match_values_in(set_of(ns)), - }), - [](DirectedEdge const &e) { return e.dst; }) - .l_to_r(), - [](nonempty_unordered_set const &s) + + std::map> by_dst = + group_by(g.query_edges(DirectedEdgeQuery{ + query_set::matchall(), + query_set::match_values_in(set_of(ns)), + }), + [](DirectedEdge const &e) { return e.dst; }) + .l_to_r(); + + std::map> result = + map_values(by_dst, + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -32,7 +38,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc index c2057472cf..883d8e3725 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc @@ -2,21 +2,26 @@ #include "utils/containers/group_by.h" #include "utils/containers/map_values.h" #include "utils/containers/set_of.h" +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { std::unordered_map> get_outgoing_edges(DiGraphView const &g, std::unordered_set const &ns) { - std::unordered_map> result = - map_values(group_by(g.query_edges(DirectedEdgeQuery{ - query_set::match_values_in(set_of(ns)), - query_set::matchall(), - }), - [](DirectedEdge const &e) { return e.src; }) - .l_to_r(), - [](nonempty_unordered_set const &s) - -> std::unordered_set { + + std::map> by_src = + group_by(g.query_edges(DirectedEdgeQuery{ + query_set::match_values_in(set_of(ns)), + query_set::matchall(), + }), + [](DirectedEdge const &e) { return e.src; }) + .l_to_r(); + + std::map> result = + map_values(by_src, + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -24,7 +29,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } std::unordered_set get_outgoing_edges(DiGraphView const &g, diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index 096efd49e9..66c04ec59c 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/node/algorithms.h" #include @@ -11,7 +11,7 @@ enum class ExplorationStatus { NOT_EXPLORED, BEING_EXPLORED, FULLY_EXPLORED }; bool is_acyclic(DiGraphView const &g) { std::unordered_map status = - generate_map(get_nodes(g), [](Node const &n) { + generate_unordered_map(get_nodes(g), [](Node const &n) { return ExplorationStatus::NOT_EXPLORED; }); diff --git a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc index 941c8e8e3e..903ba0c589 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc @@ -1,12 +1,12 @@ #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/containers/contains_key.h" #include "utils/containers/extend.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/values.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -26,8 +26,8 @@ AdjacencyMultiDiGraph::AdjacencyMultiDiGraph( Node AdjacencyMultiDiGraph::add_node() { Node new_node = this->node_source.new_node(); std::unordered_set all_nodes = - set_union(keys(this->adjacency), {new_node}); - this->adjacency[new_node] = generate_map(all_nodes, [](Node const &) { + set_union(unordered_keys(this->adjacency), {new_node}); + this->adjacency[new_node] = generate_unordered_map(all_nodes, [](Node const &) { return std::unordered_set{}; }); @@ -77,15 +77,15 @@ void AdjacencyMultiDiGraph::remove_edge(MultiDiEdge const &e) { std::unordered_set AdjacencyMultiDiGraph::query_nodes(NodeQuery const &q) const { - return apply_query(q.nodes, keys(this->adjacency)); + return apply_query(q.nodes, unordered_keys(this->adjacency)); } std::unordered_set AdjacencyMultiDiGraph::query_edges(MultiDiEdgeQuery const &q) const { std::unordered_set result; - std::unordered_set srcs = apply_query(q.srcs, keys(this->adjacency)); - std::unordered_set dsts = apply_query(q.dsts, keys(this->adjacency)); + std::unordered_set srcs = apply_query(q.srcs, unordered_keys(this->adjacency)); + std::unordered_set dsts = apply_query(q.dsts, unordered_keys(this->adjacency)); for (Node const &src : srcs) { for (Node const &dst : dsts) { extend(result, this->adjacency.at(src).at(dst)); @@ -108,8 +108,8 @@ void AdjacencyMultiDiGraph::inplace_materialize_from( std::unordered_set nodes = get_nodes(g); std::unordered_set edges = get_edges(g); - this->adjacency = generate_map(nodes, [&](Node const &) { - return generate_map( + this->adjacency = generate_unordered_map(nodes, [&](Node const &) { + return generate_unordered_map( nodes, [&](Node const &) { return std::unordered_set{}; }); }); this->edge_nodes.clear(); diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index db181fbe73..ddc3f4f7ae 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -7,6 +7,7 @@ #include "utils/graph/multidigraph/multidiedge_query.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/query_set.h" +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -28,11 +29,11 @@ std::unordered_map> query_set::match_values_in(set_of(ns)), }; - std::unordered_map> result = map_values( + std::map> result = map_values( group_by(g.query_edges(query), [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }) .l_to_r(), - [](nonempty_unordered_set const &s) + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -41,7 +42,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc index 826c03f476..466bb903d9 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc @@ -1,5 +1,5 @@ #include "utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" @@ -7,7 +7,7 @@ namespace FlexFlow { std::unordered_map get_multidiedge_to_diedge_map(MultiDiGraphView const &g) { - return generate_map(get_edges(g), [&](MultiDiEdge const &e) { + return generate_unordered_map(get_edges(g), [&](MultiDiEdge const &e) { return get_directed_edge(g, e); }); } diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 28e181ebb9..143e59b3db 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -5,6 +5,7 @@ #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" #include +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -26,11 +27,11 @@ std::unordered_map> query_set::matchall(), }; - std::unordered_map> result = map_values( + std::map> result = map_values( group_by(g.query_edges(query), [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }) .l_to_r(), - [](nonempty_unordered_set const &s) + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -39,7 +40,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc index 0228fdd8e9..b68aa4b1d8 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc @@ -1,5 +1,5 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/sorted_by.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" @@ -45,7 +45,7 @@ std::vector get_incoming_edges(OpenDataflowGraphView const &g, std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &g, std::unordered_set const &ns) { - return generate_map(ns, + return generate_unordered_map(ns, [&](Node const &n) { return get_incoming_edges(g, n); }); } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc index 6f1c2ace68..398abb3faf 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc @@ -13,6 +13,7 @@ #include "utils/overload.h" #include #include +#include "utils/containers/multiset_of.h" namespace FlexFlow { @@ -43,8 +44,8 @@ BinarySPDecompositionTree from_parallel_child(children[0]), from_parallel_child(children[1])}}; } - auto s1 = unordered_multiset_of(slice(children, 0, children.size() / 2)); - auto s2 = unordered_multiset_of( + auto s1 = multiset_of(slice(children, 0, children.size() / 2)); + auto s2 = multiset_of( slice(children, children.size() / 2, std::nullopt)); return BinarySPDecompositionTree{BinaryParallelSplit{ diff --git a/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc index 6105dda704..b5a07a5a4d 100644 --- a/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc @@ -11,6 +11,8 @@ #include "utils/graph/series_parallel/series_split.dtg.h" #include "utils/overload.h" #include "utils/variant.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/multiset_of.h" namespace FlexFlow { @@ -43,7 +45,7 @@ NonNormalSPDecomposition non_normal_parallel_composition( for (NonNormalSPDecomposition const &sp_comp : sp_compositions) { if (sp_comp.has()) { composition = multiset_union( - composition, sp_comp.get().get_children()); + composition, unordered_multiset_of(sp_comp.get().get_children())); } else if (sp_comp.has()) { composition.insert(sp_comp.get()); } else { @@ -51,7 +53,7 @@ NonNormalSPDecomposition non_normal_parallel_composition( composition.insert(sp_comp.get()); } } - return NonNormalSPDecomposition(NonNormalParallelSplit{composition}); + return NonNormalSPDecomposition(NonNormalParallelSplit{multiset_of(composition)}); } static Node as_non_normal(Node const &n) { @@ -70,11 +72,11 @@ static NonNormalSeriesSplit as_non_normal(SeriesSplit const &s) { static NonNormalParallelSplit as_non_normal(ParallelSplit const &p) { return non_normal_parallel_composition( - transform(p.get_children(), + unordered_multiset_of(transform(p.get_children(), [](std::variant const &child) { return as_non_normal( widen(child)); - })) + }))) .get(); } diff --git a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc index 3851ca38d9..5eda579f81 100644 --- a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc @@ -6,6 +6,7 @@ #include "utils/graph/series_parallel/non_normal_sp_decomposition.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/variant.h" +#include "utils/containers/unordered_multiset_of.h" namespace FlexFlow { @@ -41,7 +42,7 @@ static SeriesParallelDecomposition static SeriesParallelDecomposition normalize_sp_decomposition(NonNormalParallelSplit const ¶llel) { - std::unordered_multiset normalized_children = + std::multiset normalized_children = transform(filter_empty(parallel.get_children()), [](std::variant const &child) { return normalize_sp_decomposition( @@ -54,7 +55,7 @@ static SeriesParallelDecomposition if (normalized_children.size() == 1) { return get_only(normalized_children); } - return parallel_composition(normalized_children); + return parallel_composition(unordered_multiset_of(normalized_children)); } SeriesParallelDecomposition diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index e0075c2584..8c9f655fe5 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -16,6 +16,7 @@ #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/variant.h" #include +#include "utils/containers/multiset_of.h" namespace FlexFlow { @@ -31,7 +32,7 @@ struct ToFinalAST { .value(); })}; } else { - return ParallelSplit{unordered_multiset_of(transform( + return ParallelSplit{multiset_of(transform( node.children, [](std::variant const &s) { return narrow>( @@ -137,7 +138,7 @@ SeriesParallelDecomposition parallel_composition( for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { if (sp_comp.has()) { composition = multiset_union(composition, - sp_comp.get().get_children()); + unordered_multiset_of(sp_comp.get().get_children())); } else if (sp_comp.has()) { composition.insert(sp_comp.get()); } else { @@ -145,7 +146,7 @@ SeriesParallelDecomposition parallel_composition( composition.insert(sp_comp.get()); } } - return SeriesParallelDecomposition(ParallelSplit{composition}); + return SeriesParallelDecomposition(ParallelSplit{multiset_of(composition)}); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc index fc7cad225a..590912e93f 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc @@ -5,6 +5,7 @@ #include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/multiset.h" #include "utils/graph/digraph/algorithms/get_edges.h" #include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" #include "utils/graph/digraph/digraph_view.h" @@ -15,6 +16,7 @@ #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/variant.h" #include + namespace FlexFlow { static std::unordered_map diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc index 36b8e7294b..e3e24294b9 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc @@ -147,7 +147,7 @@ static std::unordered_set return filter_out_sync_nodes(forest, node_roles); } -static std::pair, nonempty_unordered_set> +static std::pair, nonempty_set> get_up_and_down_sets( DiGraph const &g, std::unordered_set const &forest, @@ -229,7 +229,7 @@ SeriesParallelDecomposition escribano_sp_ization(DiGraph g) { std::unordered_set forest = get_forest_escribano(sp, handle, component, node_roles); - std::pair, nonempty_unordered_set> + std::pair, nonempty_set> up_down_sets = get_up_and_down_sets(sp, forest, depth_map); std::unordered_set up = up_down_sets.first.unwrap_as_unordered_set(); diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc index 7206ec5cda..b22d7d3811 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc @@ -3,7 +3,7 @@ #include "utils/containers/argmin.h" #include "utils/containers/contains.h" #include "utils/containers/filter.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" #include "utils/containers/is_subseteq_of.h" @@ -233,7 +233,7 @@ static std::unordered_set ASSERT(!candidate_nodes.empty()); std::unordered_map critical_path_costs = - generate_map(candidate_nodes, [&](Node const &node) { + generate_unordered_map(candidate_nodes, [&](Node const &node) { std::unordered_set preds = get_predecessors(g, node); float max_parent_cost = maximum(transform(preds, [&](Node const &pred) { return sp_longest_paths.at(pred); @@ -253,7 +253,7 @@ static std::unordered_set static bool cost_map_is_valid(DiGraphView const &g, std::unordered_map const &cost_map) { - bool has_correct_nodes = get_nodes(g) == keys(cost_map); + bool has_correct_nodes = (get_nodes(g) == unordered_keys(cost_map)); bool has_nonnegative_costs = all_of(values(cost_map), [&](float const &cost) { return cost >= 0.0f; }); return has_correct_nodes && has_nonnegative_costs; diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc index 3c38b23f2b..4ebaf89756 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc @@ -1,6 +1,5 @@ #include "utils/graph/series_parallel/sp_ization/naive_stratum_sync.h" #include "utils/containers/group_by.h" -#include "utils/containers/keys.h" #include "utils/containers/maximum.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" @@ -13,6 +12,7 @@ #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" #include +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -21,7 +21,7 @@ std::vector> std::unordered_map node_to_stratum = get_longest_path_lengths_from_root(g); - std::unordered_set nodes = keys(node_to_stratum); + std::unordered_set nodes = unordered_keys(node_to_stratum); OneToMany strata_to_nodes = group_by(nodes, [&](Node const &n) { return node_to_stratum.at(n); }); diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc index 97b8f11ec3..a6d4183a23 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc @@ -1,5 +1,5 @@ #include "utils/graph/series_parallel/sp_ization/node_role.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_map get_initial_node_role_map(DiGraphView const &g) { - return generate_map(get_nodes(g), + return generate_unordered_map(get_nodes(g), [](Node const &) { return NodeRole::PURE; }); } diff --git a/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc b/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc index a12de02656..b987153285 100644 --- a/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc +++ b/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc @@ -1,11 +1,11 @@ #include "utils/many_to_one/exhaustive_relational_join.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template ManyToOne exhaustive_relational_join(ManyToOne const &, diff --git a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc index 92570a1c7f..adb63f5fd8 100644 --- a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc +++ b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc @@ -1,10 +1,10 @@ #include "utils/many_to_one/invert_many_to_one.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template OneToMany invert_many_to_one(ManyToOne const &); diff --git a/lib/utils/src/utils/many_to_one/many_to_one.cc b/lib/utils/src/utils/many_to_one/many_to_one.cc index bbb3bfcc14..52ae52153b 100644 --- a/lib/utils/src/utils/many_to_one/many_to_one.cc +++ b/lib/utils/src/utils/many_to_one/many_to_one.cc @@ -1,5 +1,5 @@ #include "utils/many_to_one/many_to_one.h" -#include "utils/archetypes/jsonable_value_type.h" +#include "utils/archetypes/jsonable_ordered_value_type.h" #include "utils/archetypes/rapidcheckable_value_type.h" #include "utils/archetypes/value_type.h" @@ -17,12 +17,18 @@ template std::unordered_map, R> template std::ostream &operator<<(std::ostream &, ManyToOne const &); +template std::unordered_set> + unstructured_relation_from_many_to_one(ManyToOne const &); + +template ManyToOne many_to_one_from_unstructured_relation( + std::unordered_set> const &); + } // namespace FlexFlow namespace nlohmann { -using L = ::FlexFlow::jsonable_value_type<0>; -using R = ::FlexFlow::jsonable_value_type<1>; +using L = ::FlexFlow::jsonable_ordered_value_type<0>; +using R = ::FlexFlow::jsonable_ordered_value_type<1>; template struct adl_serializer<::FlexFlow::ManyToOne>; diff --git a/lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc b/lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc deleted file mode 100644 index dc03030f20..0000000000 --- a/lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template ManyToOne many_to_one_from_unstructured_relation( - std::unordered_set> const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc b/lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc deleted file mode 100644 index d89df51b79..0000000000 --- a/lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template std::unordered_set> - unstructured_relation_from_many_to_one(ManyToOne const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/nonempty_set/nonempty_set.cc b/lib/utils/src/utils/nonempty_set/nonempty_set.cc new file mode 100644 index 0000000000..1af2951f10 --- /dev/null +++ b/lib/utils/src/utils/nonempty_set/nonempty_set.cc @@ -0,0 +1,23 @@ +#include "utils/nonempty_set/nonempty_set.h" +#include "utils/archetypes/ordered_value_type.h" + +using T = ::FlexFlow::ordered_value_type<0>; + +namespace FlexFlow { + +template struct nonempty_set; + +template bool operator==(std::set const &, nonempty_set const &); + +template bool operator!=(std::set const &, nonempty_set const &); + +template std::set format_as(nonempty_set const &); +template std::ostream &operator<<(std::ostream &, nonempty_set const &); + +} // namespace FlexFlow + +namespace std { + +template struct hash<::FlexFlow::nonempty_set>; + +} // namespace std diff --git a/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc b/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc index 6d237732e9..ae8fdc0d99 100644 --- a/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc +++ b/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc @@ -1,11 +1,11 @@ #include "utils/one_to_many/exhaustive_relational_join.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template OneToMany exhaustive_relational_join(OneToMany const &, diff --git a/lib/utils/src/utils/one_to_many/invert_one_to_many.cc b/lib/utils/src/utils/one_to_many/invert_one_to_many.cc index cb911ff60a..45edf29b3c 100644 --- a/lib/utils/src/utils/one_to_many/invert_one_to_many.cc +++ b/lib/utils/src/utils/one_to_many/invert_one_to_many.cc @@ -1,10 +1,10 @@ #include "utils/one_to_many/invert_one_to_many.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template ManyToOne invert_one_to_many(OneToMany const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many.cc b/lib/utils/src/utils/one_to_many/one_to_many.cc index 158d2e10c9..ce6220c509 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many.cc @@ -1,28 +1,32 @@ #include "utils/one_to_many/one_to_many.h" #include "utils/archetypes/jsonable_value_type.h" #include "utils/archetypes/rapidcheckable_value_type.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/jsonable_ordered_value_type.h" using namespace ::FlexFlow; namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template struct OneToMany; -template std::unordered_map> +template std::map> format_as(OneToMany const &); template std::ostream &operator<<(std::ostream &, OneToMany const &); +template std::unordered_set> + unstructured_relation_from_one_to_many(OneToMany const &); + } // namespace FlexFlow namespace nlohmann { -using L = ::FlexFlow::jsonable_value_type<0>; -using R = ::FlexFlow::jsonable_value_type<1>; +using L = ::FlexFlow::jsonable_ordered_value_type<0>; +using R = ::FlexFlow::jsonable_ordered_value_type<1>; template struct adl_serializer<::FlexFlow::OneToMany>; @@ -40,8 +44,8 @@ template struct Arbitrary<::FlexFlow::OneToMany>; namespace std { -using L = ::FlexFlow::value_type<0>; -using R = ::FlexFlow::value_type<1>; +using L = ::FlexFlow::ordered_value_type<0>; +using R = ::FlexFlow::ordered_value_type<1>; template struct hash>; diff --git a/lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc b/lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc new file mode 100644 index 0000000000..a8855e1857 --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/one_to_many_filter_keys.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; +using F = std::function; + +template OneToMany one_to_many_filter_keys(OneToMany const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc b/lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc new file mode 100644 index 0000000000..fa11ccf56f --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc @@ -0,0 +1,13 @@ +#include "utils/one_to_many/one_to_many_filter_values.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; + +using F = std::function; + +template OneToMany one_to_many_filter_values(OneToMany const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc index bd6c976488..3dad3d6a1d 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc @@ -1,10 +1,10 @@ #include "utils/one_to_many/one_to_many_from_bidict.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template OneToMany one_to_many_from_bidict(bidict const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc index 76f0a221c5..124adb20c3 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc @@ -1,10 +1,10 @@ #include "utils/one_to_many/one_to_many_from_l_to_r_mapping.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template OneToMany one_to_many_from_l_to_r_mapping( std::unordered_map> const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc deleted file mode 100644 index 4fdd52ef2e..0000000000 --- a/lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template OneToMany one_to_many_from_unstructured_relation( - std::unordered_set> const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc b/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc index 141db7f1da..47a4652ce2 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc @@ -1,11 +1,11 @@ #include "utils/one_to_many/one_to_many_transform_values.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R1 = value_type<1>; -using R2 = value_type<2>; +using L = ordered_value_type<0>; +using R1 = ordered_value_type<1>; +using R2 = ordered_value_type<2>; using F = std::function; template OneToMany one_to_many_transform_values(OneToMany const &, diff --git a/lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc b/lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc deleted file mode 100644 index 9a48510a3a..0000000000 --- a/lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template std::unordered_set> - unstructured_relation_from_one_to_many(OneToMany const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/dim_domain_mapping.cc b/lib/utils/src/utils/orthotope/dim_domain_mapping.cc index 03762dadca..bd0f46e3dc 100644 --- a/lib/utils/src/utils/orthotope/dim_domain_mapping.cc +++ b/lib/utils/src/utils/orthotope/dim_domain_mapping.cc @@ -1,9 +1,9 @@ #include "utils/orthotope/dim_domain_mapping.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" -using ::FlexFlow::value_type; -using L = value_type<0>; -using R = value_type<1>; +using ::FlexFlow::ordered_value_type; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; namespace FlexFlow { @@ -32,9 +32,9 @@ template DimDomainMapping DimOrdering const &, DimOrdering const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template DimDomainMapping compose_dim_domain_mappings(DimDomainMapping const &, diff --git a/lib/utils/src/utils/orthotope/dim_projection.cc b/lib/utils/src/utils/orthotope/dim_projection.cc index fdf0472a36..9fa250fa47 100644 --- a/lib/utils/src/utils/orthotope/dim_projection.cc +++ b/lib/utils/src/utils/orthotope/dim_projection.cc @@ -1,10 +1,11 @@ #include "utils/orthotope/dim_projection.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template DimProjection dim_projection_identity_map(DimDomain const &, @@ -27,9 +28,9 @@ template DimCoord compute_dim_projection(DimProjection const &, DimOrdering const &, DimOrdering const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template DimProjection right_compose_eq_projection(DimProjection const &, diff --git a/lib/utils/src/utils/orthotope/down_projection.cc b/lib/utils/src/utils/orthotope/down_projection.cc index 73842ecc11..684521a5de 100644 --- a/lib/utils/src/utils/orthotope/down_projection.cc +++ b/lib/utils/src/utils/orthotope/down_projection.cc @@ -1,10 +1,10 @@ #include "utils/orthotope/down_projection.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template DownProjection make_empty_down_projection(); @@ -26,9 +26,9 @@ template void project_dims(DownProjection &, template UpProjection invert_down_projection(DownProjection const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template DownProjection compose_down_projections(DownProjection const &, diff --git a/lib/utils/src/utils/orthotope/eq_projection.cc b/lib/utils/src/utils/orthotope/eq_projection.cc index 877b3f93ae..a6965dc36e 100644 --- a/lib/utils/src/utils/orthotope/eq_projection.cc +++ b/lib/utils/src/utils/orthotope/eq_projection.cc @@ -1,10 +1,10 @@ #include "utils/orthotope/eq_projection.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template EqProjection make_empty_eq_projection(); @@ -18,9 +18,9 @@ template void project_dims(EqProjection &, L const &, R const &); template EqProjection invert_eq_projection(EqProjection const &); -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template EqProjection compose_eq_projections(EqProjection const &, diff --git a/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc b/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc index a867abfd0a..5d4c47b491 100644 --- a/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc +++ b/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc @@ -1,9 +1,9 @@ #include "utils/orthotope/minimal_dim_domain_mapping.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" -using ::FlexFlow::value_type; -using L = value_type<0>; -using R = value_type<1>; +using ::FlexFlow::ordered_value_type; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; namespace FlexFlow { @@ -40,9 +40,9 @@ template MinimalDimDomainMapping DimOrdering const &, DimOrdering const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template MinimalDimDomainMapping compose_minimal_dim_domain_mappings( MinimalDimDomainMapping const &, diff --git a/lib/utils/src/utils/orthotope/up_projection.cc b/lib/utils/src/utils/orthotope/up_projection.cc index 604cec08ed..0c8909dffd 100644 --- a/lib/utils/src/utils/orthotope/up_projection.cc +++ b/lib/utils/src/utils/orthotope/up_projection.cc @@ -1,12 +1,11 @@ #include "utils/orthotope/up_projection.h" #include "utils/archetypes/ordered_value_type.h" -#include "utils/archetypes/value_type.h" namespace FlexFlow { -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template UpProjection compose_up_projections(UpProjection const &, diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_keys.cc similarity index 64% rename from lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filter_keys.cc index 3c3097fa9b..2a0806324b 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_keys.cc @@ -1,20 +1,21 @@ -#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/bidict/algorithms/bidict_filter_keys.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filter_keys(bidict, F)") { + TEST_CASE("bidict_filter_keys(bidict, F)") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filter_keys(dict, [](int k) { return k == 1; }); + bidict_filter_keys(dict, [](int k) { return k == 1; }); bidict correct = { {1, "one"}, }; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_values.cc similarity index 61% rename from lib/utils/test/src/utils/bidict/algorithms/filter_values.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filter_values.cc index 54d0bad199..5c67e11c4f 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_values.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/filter_values.h" +#include "utils/bidict/algorithms/bidict_filter_values.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filter_values(bidict, F") { + TEST_CASE("bidict_filter_values(bidict, F") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filter_values(dict, [](std::string const &v) { return v == "two"; }); + bidict_filter_values(dict, [](std::string const &v) { return v == "two"; }); bidict correct = { {2, "two"}, }; diff --git a/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_keys.cc similarity index 75% rename from lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_keys.cc index 300918f978..a2774440e5 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_keys.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/filtrans_keys.h" +#include "utils/bidict/algorithms/bidict_filtrans_keys.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filtrans_keys(bidict, F)") { + TEST_CASE("bidict_filtrans_keys") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filtrans_keys(dict, [](int k) -> std::optional { + bidict_filtrans_keys(dict, [](int k) -> std::optional { if (k == 1) { return std::nullopt; } else { diff --git a/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_values.cc similarity index 69% rename from lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_values.cc index 99aaef114d..8687d539bc 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_values.cc @@ -1,26 +1,28 @@ -#include "utils/bidict/algorithms/filtrans_values.h" +#include "utils/bidict/algorithms/bidict_filtrans_values.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filtrans_values(bidict, F)") { + TEST_CASE("bidict_filtrans_values") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filtrans_values(dict, [](std::string const &v) -> std::optional { + bidict_filtrans_values(dict, [](std::string const &v) -> std::optional { if (v == "two") { return std::nullopt; } else { return v.size() + 1; } }); + bidict correct = { {1, 4}, }; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/bidict/algorithms/unordered_set_of.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_unordered_set_of.cc similarity index 77% rename from lib/utils/test/src/utils/bidict/algorithms/unordered_set_of.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_unordered_set_of.cc index b88b6df0ca..d44a2fe62b 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/unordered_set_of.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_unordered_set_of.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 1365d04027..f15f15b0fe 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -113,11 +113,13 @@ TEST_SUITE(FF_TEST_SUITE) { bidict deserialized = bidict{ {2, "hello"}, {3, "goodbye"}, + {4, "yes"}, }; nlohmann::json serialized = std::vector>{ {2, "hello"}, {3, "goodbye"}, + {4, "yes"}, }; SUBCASE("to_json") { diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 2fdb2e481e..22bae1f613 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -4,7 +4,7 @@ #include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" @@ -50,7 +50,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_multiset correct_values = {"A", "B", "C", "D"}; std::map result = enumerate(input); - CHECK(keys(result) == correct_keys); + CHECK(unordered_keys(result) == correct_keys); CHECK(unordered_multiset_of(values(result)) == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc index 5bdaef6d08..d2ac3dbfba 100644 --- a/lib/utils/test/src/utils/containers/keys.cc +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -1,5 +1,5 @@ #include "utils/containers/keys.h" -#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/set.h" #include #include #include @@ -9,10 +9,10 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("keys") { - std::unordered_map m = { + std::map m = { {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {1, 2, 3}; + std::set result = keys(m); + std::set expected = {1, 2, 3}; CHECK(result == expected); } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc index 4435ccc26c..e92a37169b 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominators_map.h" -#include "utils/containers/generate_map.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc index 8e276daca7..85c4d66f6d 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc @@ -1,6 +1,6 @@ #include "utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.h" #include "test/utils/rapidcheck.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_initial_nodes.h" #include "utils/graph/digraph/algorithms/get_terminal_nodes.h" @@ -46,7 +46,7 @@ static std::pair> } std::unordered_map cost_map = - generate_map(get_nodes(g), [](Node const &) { + generate_unordered_map(get_nodes(g), [](Node const &) { return static_cast(*rc::gen::inRange(1, 101)); }); diff --git a/lib/utils/test/src/utils/many_to_one/many_to_one.cc b/lib/utils/test/src/utils/many_to_one/many_to_one.cc index 13f88fab7c..ce219676e1 100644 --- a/lib/utils/test/src/utils/many_to_one/many_to_one.cc +++ b/lib/utils/test/src/utils/many_to_one/many_to_one.cc @@ -96,6 +96,37 @@ TEST_SUITE(FF_TEST_SUITE) { } } + TEST_CASE("adl_serializer>") { + ManyToOne deserialized = ManyToOne{ + {{2, 20}, {"two"}}, + {{3}, "three"}, + {{4, 40, 400}, "four"}, + }; + + nlohmann::json serialized = std::set>{ + {2, "two"}, + {3, "three"}, + {4, "four"}, + {20, "two"}, + {40, "four"}, + {400, "four"}, + }; + + SUBCASE("to_json") { + nlohmann::json result = deserialized; + nlohmann::json correct = serialized; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + ManyToOne result = serialized; + ManyToOne correct = deserialized; + + CHECK(result == correct); + } + } + TEST_CASE("fmt::to_string(ManyToOne)") { ManyToOne input = ManyToOne{ {{1, 10, 100}, "one"}, @@ -107,4 +138,52 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(multiset_of(result) == multiset_of(correct)); } + + TEST_CASE("many_to_one_from_unstructured_relation") { + SUBCASE("relation is many-to-one") { + std::unordered_set> input = { + {1, "odd"}, + {2, "even"}, + {3, "odd"}, + }; + + ManyToOne result = + many_to_one_from_unstructured_relation(input); + ManyToOne correct = { + {{1, 3}, "odd"}, + {{2}, "even"}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is one-to-one") { + std::unordered_set> input = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + + ManyToOne result = + many_to_one_from_unstructured_relation(input); + ManyToOne correct = { + {{1}, "one"}, + {{2}, "two"}, + {{3}, "three"}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is not many-to-one") { + std::unordered_set> input = { + {1, "one"}, + {1, "ODD"}, + {2, "two"}, + {3, "ODD"}, + }; + + CHECK_THROWS(many_to_one_from_unstructured_relation(input)); + } + } } diff --git a/lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc b/lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc deleted file mode 100644 index c9d0e866d4..0000000000 --- a/lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("many_to_one_from_unstructured_relation") { - SUBCASE("relation is many-to-one") { - std::unordered_set> input = { - {1, "odd"}, - {2, "even"}, - {3, "odd"}, - }; - - ManyToOne result = - many_to_one_from_unstructured_relation(input); - ManyToOne correct = { - {{1, 3}, "odd"}, - {{2}, "even"}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is one-to-one") { - std::unordered_set> input = { - {1, "one"}, - {2, "two"}, - {3, "three"}, - }; - - ManyToOne result = - many_to_one_from_unstructured_relation(input); - ManyToOne correct = { - {{1}, "one"}, - {{2}, "two"}, - {{3}, "three"}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is not many-to-one") { - std::unordered_set> input = { - {1, "one"}, - {1, "ODD"}, - {2, "two"}, - {3, "ODD"}, - }; - - CHECK_THROWS(many_to_one_from_unstructured_relation(input)); - } - } -} diff --git a/lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc b/lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc deleted file mode 100644 index 15a8f5b390..0000000000 --- a/lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" -#include "test/utils/doctest/fmt/pair.h" -#include "test/utils/doctest/fmt/unordered_set.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("unstructured_relation_from_many_to_one") { - ManyToOne input = { - {{1, 3}, "odd"}, - {{2}, "even"}, - }; - - std::unordered_set> result = - unstructured_relation_from_many_to_one(input); - std::unordered_set> correct = { - {1, "odd"}, - {2, "even"}, - {3, "odd"}, - }; - - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many.cc b/lib/utils/test/src/utils/one_to_many/one_to_many.cc index d2ea7d6a0b..de149ce609 100644 --- a/lib/utils/test/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/test/src/utils/one_to_many/one_to_many.cc @@ -1,8 +1,11 @@ #include "utils/one_to_many/one_to_many.h" #include "test/utils/doctest/fmt/multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/set.h" #include "utils/containers/multiset_of.h" #include "utils/one_to_many/one_to_many_from_l_to_r_mapping.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include using namespace ::FlexFlow; @@ -31,9 +34,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("at_l") { - nonempty_unordered_set result = m.at_l(1); + nonempty_set result = m.at_l(1); - nonempty_unordered_set correct = {"one", "One", "ONE"}; + nonempty_set correct = {"one", "One", "ONE"}; CHECK(result == correct); } @@ -47,17 +50,17 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("left_values") { - std::unordered_set result = m.left_values(); + std::set result = m.left_values(); - std::unordered_set correct = {1, 2}; + std::set correct = {1, 2}; CHECK(result == correct); } SUBCASE("right_values") { - std::unordered_set result = m.right_values(); + std::set result = m.right_values(); - std::unordered_set correct = {"one", "One", "ONE", "two"}; + std::set correct = {"one", "One", "ONE", "two"}; CHECK(result == correct); } @@ -90,6 +93,35 @@ TEST_SUITE(FF_TEST_SUITE) { } } + TEST_CASE("adl_serializer>") { + OneToMany deserialized = OneToMany{ + {2, {"two", "TWO"}}, + {3, {"three"}}, + {4, {"four"}}, + }; + + nlohmann::json serialized = std::set>{ + {2, "two"}, + {2, "TWO"}, + {3, "three"}, + {4, "four"}, + }; + + SUBCASE("to_json") { + nlohmann::json result = deserialized; + nlohmann::json correct = serialized; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + OneToMany result = serialized; + OneToMany correct = deserialized; + + CHECK(result == correct); + } + } + TEST_CASE("fmt::to_string(OneToMany)") { OneToMany input = one_to_many_from_l_to_r_mapping( @@ -100,4 +132,67 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(multiset_of(result) == multiset_of(correct)); } + + TEST_CASE("unstructured_relation_from_one_to_many") { + OneToMany input = { + {1, {"one", "ONE"}}, + {2, {"two"}}, + }; + + std::unordered_set> result = + unstructured_relation_from_one_to_many(input); + std::unordered_set> correct = { + {1, "one"}, + {1, "ONE"}, + {2, "two"}, + }; + + CHECK(result == correct); + } + + TEST_CASE("one_to_many_from_unstructured_relation") { + SUBCASE("relation is one-to-many") { + std::unordered_set> input = { + {1, "one"}, + {1, "ONE"}, + {2, "two"}, + }; + + OneToMany result = + one_to_many_from_unstructured_relation(input); + OneToMany correct = { + {1, {"one", "ONE"}}, + {2, {"two"}}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is one-to-one") { + std::unordered_set> input = { + {1, "one"}, + {2, "two"}, + }; + + OneToMany result = + one_to_many_from_unstructured_relation(input); + OneToMany correct = { + {1, {"one"}}, + {2, {"two"}}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is not one-to-many") { + std::unordered_set> input = { + {1, "one"}, + {1, "ONE"}, + {2, "two"}, + {3, "ONE"}, + }; + + CHECK_THROWS(one_to_many_from_unstructured_relation(input)); + } + } } diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc b/lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc deleted file mode 100644 index 023c556c46..0000000000 --- a/lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc +++ /dev/null @@ -1,53 +0,0 @@ -#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("one_to_many_from_unstructured_relation") { - SUBCASE("relation is one-to-many") { - std::unordered_set> input = { - {1, "one"}, - {1, "ONE"}, - {2, "two"}, - }; - - OneToMany result = - one_to_many_from_unstructured_relation(input); - OneToMany correct = { - {1, {"one", "ONE"}}, - {2, {"two"}}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is one-to-one") { - std::unordered_set> input = { - {1, "one"}, - {2, "two"}, - }; - - OneToMany result = - one_to_many_from_unstructured_relation(input); - OneToMany correct = { - {1, {"one"}}, - {2, {"two"}}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is not one-to-many") { - std::unordered_set> input = { - {1, "one"}, - {1, "ONE"}, - {2, "two"}, - {3, "ONE"}, - }; - - CHECK_THROWS(one_to_many_from_unstructured_relation(input)); - } - } -} diff --git a/lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc b/lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc deleted file mode 100644 index 06a5d5ff2e..0000000000 --- a/lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" -#include "test/utils/doctest/fmt/pair.h" -#include "test/utils/doctest/fmt/unordered_set.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("unstructured_relation_from_one_to_many") { - OneToMany input = { - {1, {"one", "ONE"}}, - {2, {"two"}}, - }; - - std::unordered_set> result = - unstructured_relation_from_one_to_many(input); - std::unordered_set> correct = { - {1, "one"}, - {1, "ONE"}, - {2, "two"}, - }; - - CHECK(result == correct); - } -} From 164004d4da20c4c262bd44f9baa471d63cf8d8e0 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 28 May 2026 04:00:09 -0700 Subject: [PATCH 16/19] Pass replicate copy insertion test case --- .../src/task-spec/dynamic_graph/copy_insertion.cc | 7 +++++++ .../test/src/task-spec/dynamic_graph/copy_insertion.cc | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 08ab3b11aa..f24dd27da2 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -18,6 +18,7 @@ #include "utils/containers/set_difference.h" #include "utils/containers/transform.h" #include "utils/optional.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" namespace FlexFlow { @@ -117,6 +118,12 @@ std::unordered_set copies_for_invocation_inputs( DynamicNodeInvocation const &i, std::unordered_map const &unmapped_value_to_src_mapped_value) { + if (training_op_attrs_has_op_type(assert_unwrap(i.node_attrs.op_attrs), OperatorType::REPLICATE)) { + // copies should not be inserted before a replicate, as the replicate + // implicitly includes the copy operations + return {}; + } + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); auto map_tensor = [&](DynamicTensorSlot const &slot, diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc index fdc705dc54..31de844555 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc @@ -8,6 +8,7 @@ #include #include "task-spec/dynamic_graph/dynamic_value_attrs.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "op-attrs/ops/element_unary.h" using namespace ::FlexFlow; @@ -250,7 +251,11 @@ TEST_SUITE(FF_TEST_SUITE) { /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/invocation_mapping, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), + }, + }, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{invocation_id}}}, /*per_device_op_state=*/std::nullopt, From 032f6ba992d6e3f157e3115e241762a6ad17b3df Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 29 May 2026 02:21:15 -0700 Subject: [PATCH 17/19] Pass shard expansion test for fwd replicate --- .../abstracted_single_tensor_movement.cc | 8 +- .../abstracted_tensor_set_movement.cc | 4 +- ...racted_tensor_set_movement_across_split.cc | 4 +- .../machine_mapping/machine_mapping.cc | 4 +- .../machine_mapping_constraints.cc | 2 +- ...el_layer_guid_oblivious_machine_mapping.cc | 4 +- lib/kernels/include/kernels/accessor.h | 10 + lib/kernels/src/kernels/accessor.cc | 32 ++++ ...space_to_parallel_tensor_space_mappings.cc | 4 +- .../op-attrs/parallel_tensor_dim_degrees.cc | 4 +- lib/pcg/src/pcg/computation_graph.cc | 4 +- lib/pcg/src/pcg/computation_graph_builder.cc | 3 +- .../parallel_computation_graph.cc | 3 +- .../parallel_computation_graph_builder.cc | 3 +- .../apply_substitution/apply_substitution.cc | 8 +- .../perform_shape_inference.cc | 4 +- .../output_operator_attrs_assignment.cc | 4 +- .../include/task-spec/device_specific.h | 16 ++ ...vice_specific_per_device_op_state.dtg.toml | 1 + .../dynamic_layer_guid_t.dtg.toml | 1 + .../dynamic_graph/dynamic_node_attrs.dtg.toml | 8 +- .../dynamic_node_invocation.dtg.toml | 9 +- ...mic_node_invocation_sharding_info.dtg.toml | 6 +- .../dynamic_tensor_accessor.dtg.toml | 1 + .../dynamic_tensor_guid_t.dtg.toml | 1 + .../dynamic_tensor_slot.dtg.toml | 11 ++ .../dynamic_value_attrs.dtg.toml | 1 + .../serializable_dynamic_node_attrs.dtg.toml | 6 +- ...ializable_dynamic_node_invocation.dtg.toml | 9 +- .../serializable_dynamic_value_attrs.dtg.toml | 1 + .../training_operation_attrs.dtg.toml | 1 + .../task-spec/dynamic_graph/copy_insertion.cc | 14 +- .../dynamic_open_dataflow_graph.cc | 8 +- .../task-spec/dynamic_graph/loss_insertion.cc | 37 +++- .../dynamic_graph/machine_slicing.cc | 4 +- ...ake_dynamic_open_dataflow_graph_from_cg.cc | 6 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 5 +- .../serializable_dynamic_node_attrs.cc | 4 +- .../dynamic_graph/shard_expansion.cc | 154 ++++++++++++++-- .../dynamic_graph/update_insertion.cc | 2 +- .../task-spec/dynamic_graph/copy_insertion.cc | 1 + .../dynamic_open_dataflow_graph.cc | 42 +++-- .../dynamic_graph/machine_slicing.cc | 7 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 8 +- .../task-spec/dynamic_graph/pass_expansion.cc | 55 +++--- .../dynamic_graph/shard_expansion.cc | 173 +++++++++++++++++- .../containers/binary_merge_disjoint_maps.h | 14 +- .../binary_merge_disjoint_unordered_maps.h | 28 +++ .../utils/containers/binary_merge_maps_with.h | 26 +-- .../binary_merge_maps_with_left_dominating.h | 6 +- .../binary_merge_maps_with_right_dominating.h | 6 +- .../binary_merge_unordered_maps_with.h | 42 +++++ ...erge_unordered_maps_with_left_dominating.h | 19 ++ ...rge_unordered_maps_with_right_dominating.h | 19 ++ lib/utils/include/utils/containers/flatmap.h | 4 +- lib/utils/include/utils/containers/get_only.h | 8 +- .../include/utils/containers/map_from_pairs.h | 15 +- lib/utils/include/utils/containers/map_keys.h | 27 ++- .../include/utils/containers/map_values2.h | 15 ++ .../utils/containers/merge_disjoint_maps.h | 8 +- .../merge_disjoint_unordered_maps.h | 24 +++ .../include/utils/containers/merge_in_map.h | 6 +- .../utils/containers/merge_in_unordered_map.h | 23 +++ .../utils/containers/merge_maps_with.h | 10 +- .../merge_maps_with_right_dominating.h | 6 +- .../containers/merge_unordered_maps_with.h | 25 +++ ...rge_unordered_maps_with_right_dominating.h | 23 +++ .../include/utils/containers/restrict_keys.h | 13 ++ .../utils/containers/zip_values_strict.h | 18 ++ .../full_binary_tree/get_path_to_leaf_map.h | 4 +- .../include/utils/nonempty_set/nonempty_set.h | 34 +++- .../require_one_to_many_is_bijection.h | 23 +++ .../utils/orthotope/minimal_dim_domain.h | 4 +- .../containers/binary_merge_disjoint_maps.cc | 9 +- .../binary_merge_disjoint_unordered_maps.cc | 13 ++ .../containers/binary_merge_maps_with.cc | 7 +- .../binary_merge_maps_with_left_dominating.cc | 9 +- ...binary_merge_maps_with_right_dominating.cc | 9 +- .../binary_merge_unordered_maps_with.cc | 13 ++ ...rge_unordered_maps_with_left_dominating.cc | 13 ++ ...ge_unordered_maps_with_right_dominating.cc | 13 ++ .../src/utils/containers/map_from_pairs.cc | 17 +- lib/utils/src/utils/containers/map_keys.cc | 21 +++ lib/utils/src/utils/containers/map_values2.cc | 19 +- .../utils/containers/merge_disjoint_maps.cc | 7 +- .../merge_disjoint_unordered_maps.cc | 12 ++ .../src/utils/containers/merge_in_map.cc | 10 +- .../containers/merge_in_unordered_map.cc | 12 ++ .../src/utils/containers/merge_maps_with.cc | 7 +- .../merge_maps_with_right_dominating.cc | 7 +- .../containers/merge_unordered_maps_with.cc | 14 ++ ...ge_unordered_maps_with_right_dominating.cc | 12 ++ .../src/utils/containers/restrict_keys.cc | 20 ++ .../src/utils/containers/zip_values_strict.cc | 21 ++- .../src/utils/nonempty_set/nonempty_set.cc | 8 + .../require_one_to_many_is_bijection.cc | 12 ++ 96 files changed, 1191 insertions(+), 241 deletions(-) create mode 100644 lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h create mode 100644 lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h create mode 100644 lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h create mode 100644 lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h create mode 100644 lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h create mode 100644 lib/utils/include/utils/containers/merge_in_unordered_map.h create mode 100644 lib/utils/include/utils/containers/merge_unordered_maps_with.h create mode 100644 lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h create mode 100644 lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h create mode 100644 lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc create mode 100644 lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc create mode 100644 lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc create mode 100644 lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc create mode 100644 lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc create mode 100644 lib/utils/src/utils/containers/merge_in_unordered_map.cc create mode 100644 lib/utils/src/utils/containers/merge_unordered_maps_with.cc create mode 100644 lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc create mode 100644 lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc index 5f9300973f..e8bd602289 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc @@ -1,13 +1,13 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h" #include "utils/containers/filtermap_keys.h" -#include "utils/containers/map_from_pairs.h" #include "utils/containers/map_keys_with_value_merging.h" -#include "utils/containers/merge_maps_with.h" #include "utils/containers/require_all_same1.h" #include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" +#include "utils/containers/merge_unordered_maps_with.h" +#include "utils/containers/unordered_map_from_pairs.h" namespace FlexFlow { @@ -34,7 +34,7 @@ AbstractedSingleTensorMovement merge_abstracted_single_tensor_movements( return AbstractedSingleTensorMovement{ /*src_op_tree_path=*/require_all_same1(src_paths), /*edge_to_size=*/ - merge_maps_with(transform(vector_of(movements), + merge_unordered_maps_with(transform(vector_of(movements), [](AbstractedSingleTensorMovement const &m) { return m.edge_to_size; }), @@ -51,7 +51,7 @@ AbstractedSingleTensorMovement return AbstractedSingleTensorMovement{ /*src_op_tree_path=*/src_op_tree_path, /*edge_to_size=*/ - map_from_pairs( + unordered_map_from_pairs( transform(communications, [](AbstractedSingleTensorCommunication const &c) { return std::pair{c.edge, c.size}; diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc index 98a7d9b0b2..37bf62029f 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -3,13 +3,13 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" -#include "utils/containers/binary_merge_maps_with.h" #include "utils/containers/flatmap.h" #include "utils/containers/map_keys_with_value_merging.h" #include "utils/containers/merge_maps_with.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/hash/unordered_map.h" +#include "utils/containers/binary_merge_unordered_maps_with.h" namespace FlexFlow { @@ -63,7 +63,7 @@ TensorSetMovement concretize_abstracted_tensor_set_movement( [](TensorSetMovement const &lhs, TensorSetMovement const &rhs) -> TensorSetMovement { return TensorSetMovement{ - binary_merge_maps_with( + binary_merge_unordered_maps_with( lhs.edge_to_size, rhs.edge_to_size, [](num_bytes_t l, num_bytes_t r) { return l + r; }), diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 6ff261facd..192ada3fb6 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -44,7 +44,7 @@ AbstractedSingleTensorMovement get_abstracted_single_tensor_movement_along_edge( op_to_op_get_coord_mapping(mapping); std::unordered_map - single_comms = map_from_pairs(transform( + single_comms = unordered_map_from_pairs(transform( bidict_unordered_set_of(coord_mapping), [&](std::pair const & src_dst) -> std::pair const &edges) + [&](nonempty_set const &edges) { return merge_abstracted_single_tensor_movements(transform( unordered_multiset_of(edges.unwrap_as_unordered_set()), diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 861912efef..c7b068d121 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -7,8 +7,8 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/containers/are_disjoint.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/unordered_keys.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -49,7 +49,7 @@ MappedParallelComputationGraph MachineMapping combine_disjoint_mappings(MachineMapping const &m1, MachineMapping const &m2) { return MachineMapping{ - binary_merge_disjoint_maps(m1.machine_views, m2.machine_views), + binary_merge_disjoint_unordered_maps(m1.machine_views, m2.machine_views), }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index fe92c77def..f77d424795 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -4,7 +4,7 @@ #include "utils/containers/filtermap_keys.h" #include "utils/containers/flatmap.h" #include "utils/containers/generate_unordered_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" #include "utils/full_binary_tree/binary_tree_path.h" diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc index ac39021f6f..97814288ab 100644 --- a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -5,11 +5,11 @@ #include "op-attrs/get_operator_task_space.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/map_keys.h" #include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -17,7 +17,7 @@ ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( ParallelLayerGuidObliviousMachineMapping const &lhs, ParallelLayerGuidObliviousMachineMapping const &rhs) { return ParallelLayerGuidObliviousMachineMapping{ - binary_merge_disjoint_maps( + binary_merge_disjoint_unordered_maps( map_keys(lhs.raw_mapping, nest_inside_left_child), map_keys(rhs.raw_mapping, nest_inside_right_child)), }; diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 27d3693e7e..cacc56046f 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -43,6 +43,11 @@ class GenericTensorAccessorR { bool operator==(GenericTensorAccessorR const &) const; bool operator!=(GenericTensorAccessorR const &) const; + bool operator<(GenericTensorAccessorR const &) const; + bool operator<=(GenericTensorAccessorR const &) const; + bool operator>(GenericTensorAccessorR const &) const; + bool operator>=(GenericTensorAccessorR const &) const; + template real_type_t
const &at(TensorDimsCoord const &indices) const { ASSERT(this->device_type == DeviceType::CPU, @@ -97,6 +102,11 @@ class GenericTensorAccessorW { bool operator==(GenericTensorAccessorW const &) const; bool operator!=(GenericTensorAccessorW const &) const; + bool operator<(GenericTensorAccessorW const &) const; + bool operator<=(GenericTensorAccessorW const &) const; + bool operator>(GenericTensorAccessorW const &) const; + bool operator>=(GenericTensorAccessorW const &) const; + operator GenericTensorAccessorR() const; template diff --git a/lib/kernels/src/kernels/accessor.cc b/lib/kernels/src/kernels/accessor.cc index a3f8ead17f..75f144a57a 100644 --- a/lib/kernels/src/kernels/accessor.cc +++ b/lib/kernels/src/kernels/accessor.cc @@ -98,6 +98,22 @@ bool GenericTensorAccessorW::operator!=( return this->tie() != other.tie(); } +bool GenericTensorAccessorW::operator<(GenericTensorAccessorW const &other) const { + return this->tie() < other.tie(); +} + +bool GenericTensorAccessorW::operator<=(GenericTensorAccessorW const &other) const { + return this->tie() <= other.tie(); +} + +bool GenericTensorAccessorW::operator>(GenericTensorAccessorW const &other) const { + return this->tie() > other.tie(); +} + +bool GenericTensorAccessorW::operator>=(GenericTensorAccessorW const &other) const { + return this->tie() >= other.tie(); +} + int32_t *GenericTensorAccessorW::get_int32_ptr() const { return this->get(); } @@ -150,6 +166,22 @@ bool GenericTensorAccessorR::operator!=( return this->tie() != other.tie(); } +bool GenericTensorAccessorR::operator<(GenericTensorAccessorR const &other) const { + return this->tie() < other.tie(); +} + +bool GenericTensorAccessorR::operator<=(GenericTensorAccessorR const &other) const { + return this->tie() <= other.tie(); +} + +bool GenericTensorAccessorR::operator>(GenericTensorAccessorR const &other) const { + return this->tie() > other.tie(); +} + +bool GenericTensorAccessorR::operator>=(GenericTensorAccessorR const &other) const { + return this->tie() >= other.tie(); +} + int32_t const *GenericTensorAccessorR::get_int32_ptr() const { return this->get(); } diff --git a/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc b/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc index 618eb533ff..1a97f8b38b 100644 --- a/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc +++ b/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc @@ -8,11 +8,11 @@ #include "op-attrs/ops/weight.h" #include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" -#include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/require_only_key.h" #include "utils/containers/require_two_keys.h" #include "utils/containers/zip_values_strict.h" #include "utils/overload.h" +#include "utils/containers/merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -281,7 +281,7 @@ std::unordered_map ComputationGraphOpAttrs const &attrs, std::unordered_map const &inputs_degrees) { - return merge_disjoint_maps(std::vector{ + return merge_disjoint_unordered_maps(std::vector{ get_operator_to_input_mappings(attrs, inputs_degrees), get_operator_to_weight_mappings(attrs, inputs_degrees), get_operator_to_output_mappings(attrs, inputs_degrees), diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc index 83a7aded6a..5b5d0b514f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -5,7 +5,6 @@ #include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" #include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_space_coordinate.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/filtrans.h" #include "utils/containers/generate_unordered_map.h" @@ -19,6 +18,7 @@ #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/nonnegative_int/num_elements.h" #include "utils/orthotope/minimal_dim_domain.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -100,7 +100,7 @@ std::unordered_map return degrees.shard_degrees.at(dim); }); - return binary_merge_disjoint_maps( + return binary_merge_disjoint_unordered_maps( /*lhs=*/replica_dim_degrees, /*rhs=*/map_keys(shard_dim_degrees, [](ff_dim_t const &dim) { return parallel_tensor_dim_idx_t{dim}; diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 35ba0747f0..1b0b3b3204 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -2,7 +2,6 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" @@ -35,6 +34,7 @@ #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/node/algorithms.h" #include "utils/record_formatter.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -101,7 +101,7 @@ LayerAddedResult add_layer( KwargNodeAddedResult added = computation_graph.raw_graph.add_node( layer_attrs, - binary_merge_disjoint_maps(raw_inputs, raw_weights), + binary_merge_disjoint_unordered_maps(raw_inputs, raw_weights), output_attrs); return LayerAddedResult{ diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 40e72aee9d..2eb140fa58 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -48,6 +48,7 @@ #include "utils/fmt/set.h" #include "utils/stack_vector/stack_vector_of.h" #include +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -114,7 +115,7 @@ static void check_incoming_tensor_roles( restrict_keys(get_incoming_tensor_roles(layer.op_attrs), set_union(input_slots, weight_slots)); std::unordered_map current = - binary_merge_disjoint_maps( + binary_merge_disjoint_unordered_maps( generate_unordered_map( input_slots, [](TensorSlotName) { return IncomingTensorRole::INPUT; }), diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index c4a429a820..40e8cb9e5e 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -38,6 +38,7 @@ #include "utils/record_formatter.h" #include #include "utils/containers/map_from_unordered.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -112,7 +113,7 @@ ParallelLayerAddedResult add_parallel_layer( KwargNodeAddedResult op_added = pcg.raw_graph.add_node( layer_attrs, - binary_merge_disjoint_maps(unwrapped_inputs, unwrapped_weights), + binary_merge_disjoint_unordered_maps(unwrapped_inputs, unwrapped_weights), output_attrs); return ParallelLayerAddedResult{ diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 92334cfde9..c01b17c285 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -34,6 +34,7 @@ #include "utils/containers/transform.h" #include "utils/containers/zip_values_strict_with.h" #include "utils/containers/zip_with.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -678,7 +679,7 @@ static void check_incoming_tensor_roles( std::unordered_map correct = get_incoming_tensor_roles(layer.op_attrs); std::unordered_map current = - binary_merge_disjoint_maps( + binary_merge_disjoint_unordered_maps( generate_unordered_map( input_slots, [](TensorSlotName) { return IncomingTensorRole::INPUT; }), diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index f3ceda7a06..b8140440b7 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -9,11 +9,11 @@ #include "substitutions/sub_parallel_computation_graph_data.dtg.h" #include "substitutions/sub_parallel_computation_graph_data.h" #include "substitutions/sub_parallel_computation_graph_edge.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/unordered_keys.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" #include "utils/containers/values.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -64,7 +64,7 @@ SubParallelComputationGraph apply_substitution_from_output_result( std::unordered_map post_node_data_from_sub = output_graph_data.node_data; - return binary_merge_disjoint_maps(post_node_data_from_orig, + return binary_merge_disjoint_unordered_maps(post_node_data_from_orig, post_node_data_from_sub); }(); @@ -168,8 +168,8 @@ SubParallelComputationGraph apply_substitution_from_output_result( std::unordered_map post_value_data_from_sub = output_graph_data.value_data; - return binary_merge_disjoint_maps(post_value_data_from_orig, - post_value_data_from_sub); + return binary_merge_disjoint_unordered_maps(post_value_data_from_orig, + post_value_data_from_sub); }(); SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc index 9ae007ef16..d3ad4ca246 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -1,7 +1,6 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" #include "utils/containers/is_subseteq_of.h" @@ -20,6 +19,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" #include "utils/nonnegative_int/num_elements.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -72,7 +72,7 @@ LabelledOpenKwargDataflowGraphView weight_shapes = incoming_shapes_with_role(IncomingTensorRole::WEIGHT); - ASSERT(binary_merge_disjoint_maps(input_shapes, weight_shapes) == + ASSERT(binary_merge_disjoint_unordered_maps(input_shapes, weight_shapes) == incoming_shapes); std::unordered_map diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc index 647362ee4d..2755716e44 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc @@ -2,9 +2,9 @@ #include "substitutions/operator_pattern/get_attribute_map.h" #include "substitutions/output_graph/materialize_operator_from_attrs_map.h" #include "substitutions/output_graph/output_operator_attribute_expr.h" -#include "utils/containers/binary_merge_maps_with_right_dominating.h" #include "utils/containers/map_values.h" #include "utils/exception.h" +#include "utils/containers/binary_merge_unordered_maps_with_right_dominating.h" namespace FlexFlow { @@ -36,7 +36,7 @@ PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( }); std::unordered_map - joined_attrs_map = binary_merge_maps_with_right_dominating( + joined_attrs_map = binary_merge_unordered_maps_with_right_dominating( template_attrs_map, assignments_attrs_map); return materialize_operator_from_attrs_map(joined_attrs_map); diff --git a/lib/task-spec/include/task-spec/device_specific.h b/lib/task-spec/include/task-spec/device_specific.h index 2055888b1b..834378f415 100644 --- a/lib/task-spec/include/task-spec/device_specific.h +++ b/lib/task-spec/include/task-spec/device_specific.h @@ -25,6 +25,22 @@ struct DeviceSpecific { return this->tie() != other.tie(); } + bool operator<(DeviceSpecific const &other) const { + return this->tie() < other.tie(); + } + + bool operator<=(DeviceSpecific const &other) const { + return this->tie() <= other.tie(); + } + + bool operator>(DeviceSpecific const &other) const { + return this->tie() > other.tie(); + } + + bool operator>=(DeviceSpecific const &other) const { + return this->tie() >= other.tie(); + } + T const *get(device_id_t curr_device_idx) const { ASSERT(curr_device_idx == this->device_idx); return (T const *)this->ptr.get(); diff --git a/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml b/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml index 4435a472ce..0a32b29b25 100644 --- a/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml +++ b/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml @@ -4,6 +4,7 @@ type = "variant" features = [ "eq", "hash", + "ord", "fmt", ] diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml index 8def0ec5fb..5200bfc6a6 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml @@ -3,6 +3,7 @@ name = "dynamic_layer_guid_t" type = "variant" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml index 73c023fd40..110c963383 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml @@ -4,6 +4,7 @@ type = "struct" features = [ "eq", "hash", + "ord", "fmt", ] @@ -11,6 +12,7 @@ includes = [ "", "task-spec/dynamic_graph/dynamic_task_type.dtg.h", "pcg/machine_space_coordinate.dtg.h", + "utils/nonempty_set/nonempty_set.h", "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h", "task-spec/dynamic_graph/training_operation_attrs.dtg.h", @@ -26,10 +28,10 @@ name = "task_type" type = "std::optional<::FlexFlow::DynamicTaskType>" [[fields]] -name = "device_coord" -type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" +name = "device_coords" +type = "std::optional<::FlexFlow::nonempty_set<::FlexFlow::MachineSpaceCoordinate>>" docstring = ''' -\note Right now the \c device_coord for a copy node is sort of meaningless +\note Right now the \c device_coords for a copy node is sort of meaningless because we have one controller issuing all copies for the entire graph, no matter where they are. However the intention is this to be the "owner" or "issuer" of the copy, which matters a lot more down the road once we write the diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml index 07060106c0..4d6d27444a 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml @@ -3,6 +3,7 @@ name = "DynamicNodeInvocation" type = "struct" features = [ "eq", + "ord", "fmt", "hash", ] @@ -15,13 +16,13 @@ includes = [ ] src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/hash/map.h", + "utils/fmt/map.h", ] [[fields]] name = "inputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" [[fields]] name = "node_attrs" @@ -29,4 +30,4 @@ type = "::FlexFlow::DynamicNodeAttrs" [[fields]] name = "outputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml index a59aba92d7..00d98a2e6c 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml @@ -1,7 +1,6 @@ namespace = "FlexFlow" name = "DynamicNodeInvocationShardingInfo" type = "struct" -#include "task-spec/dynamic_graph/shard_expansion.h" features = [ "eq", "ord", @@ -14,6 +13,7 @@ includes = [ "pcg/machine_space_coordinate.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h", "task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.h", + "utils/nonempty_set/nonempty_set.h", ] src_includes = [ @@ -22,8 +22,8 @@ src_includes = [ ] [[fields]] -name = "device_coord" -type = "::FlexFlow::MachineSpaceCoordinate" +name = "device_coords" +type = "::FlexFlow::nonempty_set<::FlexFlow::MachineSpaceCoordinate>" [[fields]] name = "value_sharding" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml index 85f8f299a4..bfe30a7a5e 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml @@ -3,6 +3,7 @@ name = "DynamicTensorAccessor" type = "variant" features = [ "eq", + "ord", "fmt", "hash", ] diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml index c9171b928b..56f8ae4359 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml @@ -3,6 +3,7 @@ name = "dynamic_tensor_guid_t" type = "variant" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml index 378582f428..851b8dc83d 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml @@ -12,6 +12,7 @@ features = [ includes = [ "op-attrs/tensor_slot_name.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", + "pcg/machine_space_coordinate.dtg.h", "", ] @@ -27,3 +28,13 @@ type = "::FlexFlow::TensorSlotName" [[fields]] name = "slot_tensor_role" type = "std::optional<::FlexFlow::DynamicTensorRole>" + +[[fields]] +name = "task_shard" +type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" +docstring = ''' +\brief For representing parallel operators such as \ref ReplicateAttrs as a single operator with multiple outputs instead of fully shard expanding. + +This is done for the convenience of the runtime, as the \ref ReplicateAttrs is ultimately executed as a single NCCL task/call, so shard-expanding this operator is ultimately counter-productive. +Since the output values, however, do need to be shard-expanded for the rest of the system to work, we make Replicate a single operator with multiple outputs, each with the same \ref TensorSlotName and \ref DynamicTensorRole, but with different \ref MachineSpaceCoordinate ""s to identify which node is ultimately producing that value. +''' diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index add72764f1..1e13736fc7 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "DynamicValueAttrs" type = "struct" features = [ "eq", + "ord", "fmt", "hash", ] diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml index 3c43e1d637..c8ea415ee2 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "SerializableDynamicNodeAttrs" type = "struct" features = [ "eq", + "ord", "hash", "fmt", "json", @@ -15,6 +16,7 @@ includes = [ "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h", "task-spec/dynamic_graph/training_operation_attrs.dtg.h", + "utils/nonempty_set/nonempty_set.h", ] src_includes = [ @@ -27,8 +29,8 @@ name = "task_type" type = "std::optional<::FlexFlow::DynamicTaskType>" [[fields]] -name = "device_coord" -type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" +name = "device_coords" +type = "std::optional<::FlexFlow::nonempty_set<::FlexFlow::MachineSpaceCoordinate>>" [[fields]] name = "mapping" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml index 01f4cc8876..6051a49876 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml @@ -3,6 +3,7 @@ name = "SerializableDynamicNodeInvocation" type = "struct" features = [ "eq", + "ord", "fmt", "hash", "json", @@ -16,13 +17,13 @@ includes = [ ] src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/hash/map.h", + "utils/fmt/map.h", ] [[fields]] name = "inputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" [[fields]] name = "node_attrs" @@ -30,4 +31,4 @@ type = "::FlexFlow::SerializableDynamicNodeAttrs" [[fields]] name = "outputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml index d3cab6ecdb..d05d8e011e 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "SerializableDynamicValueAttrs" type = "struct" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 8f8f6467c8..2c4e739571 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "TrainingOperationAttrs" type = "variant" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index f24dd27da2..ebf84f2d81 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -131,7 +131,7 @@ std::unordered_set copies_for_invocation_inputs( return map_dynamic_value_attrs_for_task_group(slot, value, mapping); }; - std::unordered_map mapped_inputs = + std::map mapped_inputs = map_values2(i.inputs, map_tensor); std::unordered_set result; @@ -152,6 +152,7 @@ std::unordered_set copies_for_invocation_inputs( DynamicTensorSlot{ TensorSlotName::INPUT, slot.slot_tensor_role, + /*task_shard=*/std::nullopt, }, filtered_source, }, @@ -170,8 +171,11 @@ std::unordered_set copies_for_invocation_inputs( /*outputs=*/ { { - DynamicTensorSlot{TensorSlotName::OUTPUT, - slot.slot_tensor_role}, + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + slot.slot_tensor_role, + /*task_shard=*/std::nullopt, + }, filtered_use, }, }, @@ -196,9 +200,9 @@ std::unordered_set perform_copy_insertion_for_invocation( }; DynamicNodeInvocation mapped_i = [&] { - std::unordered_map mapped_inputs = + std::map mapped_inputs = map_values2(i.inputs, map_tensor); - std::unordered_map mapped_outputs = + std::map mapped_outputs = map_values2(i.outputs, map_tensor); DynamicNodeInvocation r = i; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index a100c3adfb..38d49ef183 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -20,6 +20,8 @@ #include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" #include "utils/many_to_one/many_to_one.h" #include "utils/containers/require_all_of.h" +#include "utils/containers/unordered_map_from_map.h" +#include "utils/containers/map_from_unordered.h" namespace FlexFlow { @@ -216,16 +218,16 @@ std::pair void { KwargNodeAddedResult added = result.add_node( invocation.node_attrs, - map_values(invocation.inputs, + map_values(unordered_map_from_map(invocation.inputs), [&](DynamicValueAttrs const &input) -> OpenKwargDataflowValue { return value_map.at_r(input); }), - invocation.outputs); + unordered_map_from_map(invocation.outputs)); node_map.equate(added.node, invocation); for (auto const &[k, v] : - zip_values_strict(invocation.outputs, added.outputs)) { + zip_values_strict(invocation.outputs, map_from_unordered(added.outputs))) { DynamicValueAttrs invocation_output = v.first; KwargDataflowOutput graph_output = v.second; value_map.equate( diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 8066926262..8ff24b51ad 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -18,6 +18,7 @@ LossInsertionResult perform_loss_insertion( LossAttrs const &loss_attrs, dynamic_tensor_guid_t logit_tensor, std::optional const &loss_mapping) { + DynamicValueAttrs logit_value = assert_unwrap( find_output_value_attrs(dg, logit_tensor, mk_dynamic_tensor_role_fwd())); @@ -29,6 +30,7 @@ LossInsertionResult perform_loss_insertion( /*accessor=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_loss(), }; + DynamicValueAttrs logit_grad_value{ /*tensor_guid=*/logit_value.tensor_guid, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, @@ -37,14 +39,25 @@ LossInsertionResult perform_loss_insertion( /*accessor=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_bwd(), }; + DynamicNodeInvocation loss_invocation{ /*inputs=*/{ - {DynamicTensorSlot{/*slot_name=*/TensorSlotName::INPUT, - /*slot_tensor_role=*/label_value.role}, - label_value}, - {DynamicTensorSlot{/*slot_name=*/TensorSlotName::LOGIT, - /*slot_tensor_role=*/logit_value.role}, - logit_value}, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/label_value.role, + /*task_shard=*/std::nullopt, + }, + label_value, + }, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::LOGIT, + /*slot_tensor_role=*/logit_value.role, + /*task_shard=*/std::nullopt, + }, + logit_value, + }, }, /*node_attrs=*/ DynamicNodeAttrs{ @@ -57,11 +70,17 @@ LossInsertionResult perform_loss_insertion( }, /*outputs=*/ { - {DynamicTensorSlot{/*slot_name=*/TensorSlotName::LOGIT, - /*slot_tensor_role=*/logit_grad_value.role}, - logit_grad_value}, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::LOGIT, + /*slot_tensor_role=*/logit_grad_value.role, + /*task_shard=*/std::nullopt, + }, + logit_grad_value, + }, }, }; + DynamicOpenDataflowGraph result = dg; result.invocations.insert(loss_invocation); return LossInsertionResult{result, label_value, logit_grad_value}; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc index 0a22015ddf..6dd73bed7d 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc @@ -8,9 +8,9 @@ std::unordered_set DynamicNodeInvocation const &invocation, MachineSpaceCoordinate const &device_coord) { - ASSERT(invocation.node_attrs.device_coord.has_value()); + ASSERT(invocation.node_attrs.device_coords.has_value()); - if (invocation.node_attrs.device_coord.value() == device_coord) { + if (contains(invocation.node_attrs.device_coords.value(), device_coord)) { return {invocation}; } else { return {}; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 7fe3927fd1..740a93415e 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -11,6 +11,7 @@ #include #include #include +#include "utils/containers/map_from_unordered.h" namespace FlexFlow { @@ -39,6 +40,7 @@ DynamicOpenDataflowGraph DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, @@ -50,6 +52,7 @@ DynamicOpenDataflowGraph }, }; }); + std::unordered_map result_outputs = transform( get_outgoing_tensors(cg, layer), @@ -59,6 +62,7 @@ DynamicOpenDataflowGraph DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, @@ -71,7 +75,7 @@ DynamicOpenDataflowGraph }; }); - result.invocations.emplace(result_inputs, result_attrs, result_outputs); + result.invocations.emplace(map_from_unordered(result_inputs), result_attrs, map_from_unordered(result_outputs)); } return result; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 391ebaff3b..9c7638440f 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -43,6 +43,7 @@ DynamicNodeInvocation make_dynamic_node_invocation_from_mapped( DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{tensor.guid}, @@ -62,9 +63,9 @@ DynamicNodeInvocation make_dynamic_node_invocation_from_mapped( transform(invocation_info.outgoing, lift_kv_pair); DynamicNodeInvocation invocation = DynamicNodeInvocation{ - /*inputs=*/unordered_map_from_map(result_inputs), + /*inputs=*/result_inputs, /*node_attrs=*/result_attrs, - /*outputs=*/unordered_map_from_map(result_outputs), + /*outputs=*/result_outputs, }; return invocation; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc index d613194d14..67ae7b58f3 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc @@ -7,7 +7,7 @@ SerializableDynamicNodeAttrs dynamic_node_attrs_to_serializable(DynamicNodeAttrs const &attrs) { return SerializableDynamicNodeAttrs{ /*task_type=*/attrs.task_type, - /*device_coord=*/attrs.device_coord, + /*device_coords=*/attrs.device_coords, /*mapping=*/attrs.mapping, /*op_attrs=*/attrs.op_attrs, /*layer_guid=*/attrs.layer_guid, @@ -18,7 +18,7 @@ DynamicNodeAttrs dynamic_node_attrs_from_serializable( SerializableDynamicNodeAttrs const &attrs) { return DynamicNodeAttrs{ /*task_type=*/attrs.task_type, - /*device_coord=*/attrs.device_coord, + /*device_coords=*/attrs.device_coords, /*mapping=*/attrs.mapping, /*op_attrs=*/attrs.op_attrs, /*layer_guid=*/attrs.layer_guid, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index badd376a8b..db9484e361 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -11,11 +11,19 @@ #include "task-spec/dynamic_graph/dynamic_node_invocation.h" #include "utils/containers/map_from_unordered.h" #include "utils/one_to_many/one_to_many_filter_keys.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" +#include "utils/one_to_many/require_one_to_many_is_bijection.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h" +#include "utils/bidict/algorithms/bidict_filter_values.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/containers/merge_disjoint_maps.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/require_only_key.h" namespace FlexFlow { bool node_is_shard_expanded(DynamicNodeAttrs const &n) { - return n.device_coord.has_value(); + return n.device_coords.has_value(); } bool node_is_ready_for_shard_expansion(DynamicNodeAttrs const &n) { @@ -138,16 +146,15 @@ static DynamicNodeInvocationShardingInfo invocation_sharding_info_for_binding( DynamicNodeAttrs expanded_node_attrs = [&]() { DynamicNodeAttrs result = i.node_attrs; - result.device_coord = machine_coord; + result.device_coords = nonempty_set{machine_coord}; return result; }(); return DynamicNodeInvocationShardingInfo{ - /*device_coord=*/machine_coord, - /*value_sharding=*/map_from_unordered( - map_values2( + /*device_coord=*/nonempty_set{machine_coord}, + /*value_sharding=*/map_values2( binary_merge_disjoint_maps(i.inputs, i.outputs), - shard_expand_value_attrs)), + shard_expand_value_attrs), }; } @@ -177,7 +184,7 @@ static DynamicNodeInvocation shard_invocation_for_binding( DynamicNodeAttrs expanded_node_attrs = [&]() { DynamicNodeAttrs result = i.node_attrs; - result.device_coord = machine_coord; + result.device_coords = nonempty_set{machine_coord}; return result; }(); @@ -219,6 +226,121 @@ static std::set }); } +static std::set + generate_shard_expansion_for_fwd_replicate(DynamicNodeInvocation const &i) { + ASSERT(i.node_attrs.task_type == DynamicTaskType::FWD); + + MappedOperatorTaskGroup node_mapping = assert_unwrap(i.node_attrs.mapping); + + DynamicTensorSlot expected_input_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs input = require_only_key(i.inputs, expected_input_slot); + + DynamicTensorSlot expected_output_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs output = require_only_key(i.outputs, expected_output_slot); + + bidict + input_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(input.mapping)); + + std::set input_tensor_shards = set_of(input_value_mapping.left_values()); + + bidict + output_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(output.mapping)); + + auto get_task_shard_machine_coords_for_input_tensor_shard + = [&](ParallelTensorSpaceCoordinate const &input_tensor_shard) + -> nonempty_set + { + bidict dependent_on_input_tensor_shard + = bidict_filter_values( + node_mapping.get_shard_bindings(), + [&](OperatorAtomicTaskShardBinding const &b) -> bool { + return ptensor_space_coord_for_slot_name(b, TensorSlotName::INPUT) == input_tensor_shard; + }); + + return nonempty_set(set_of(dependent_on_input_tensor_shard.left_values())); + }; + + auto invocation_sharding_info_for_input_tensor_shard = [&](ParallelTensorSpaceCoordinate const &c) + -> DynamicNodeInvocationShardingInfo + { + nonempty_set task_shard_machine_coords = + get_task_shard_machine_coords_for_input_tensor_shard(c); + + std::map output_sharding_infos = + generate_map(task_shard_machine_coords.unwrap_as_set(), + [&](MachineSpaceCoordinate const &mc) + -> DynamicValueAttrsShardingInfo + { + ParallelTensorSpaceCoordinate pc = output_value_mapping.at_r(mc); + + return DynamicValueAttrsShardingInfo{ + /*shard_coord=*/pc, + /*mapping=*/OneToMany{ + { + pc, + {mc}, + }, + }, + }; + }); + + std::map keyed_output_sharding_infos = + map_keys(output_sharding_infos, + [&](MachineSpaceCoordinate const &mc) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/mc, + }; + }); + + DynamicTensorSlot input_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrsShardingInfo input_sharding_info = DynamicValueAttrsShardingInfo{ + /*shard_coord=*/c, + /*mapping=*/OneToMany{ + { + c, + {input_value_mapping.at_l(c)}, + }, + }, + }; + + std::map sharding_infos = + binary_merge_disjoint_maps( + keyed_output_sharding_infos, + std::map{ + { + input_slot, + input_sharding_info, + }, + }); + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/task_shard_machine_coords, + /*value_sharding=*/sharding_infos, + }; + }; + + return transform(input_tensor_shards, invocation_sharding_info_for_input_tensor_shard); +} + std::unordered_set perform_shard_expansion_for_invocation(DynamicNodeInvocation const &i) { @@ -259,10 +381,10 @@ void require_graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const & DynamicNodeAttrs apply_dynamic_node_attrs_sharding_info( DynamicNodeAttrs const &node_attrs, - MachineSpaceCoordinate const &device_coord) + nonempty_set const &device_coords) { DynamicNodeAttrs result = node_attrs; - result.device_coord = device_coord; + result.device_coords = device_coords; return result; } @@ -283,14 +405,17 @@ DynamicNodeInvocation apply_dynamic_node_invocation_sharding_info( { require_invocation_is_ready_for_shard_expansion(invocation); - auto shard_value = [&](DynamicTensorSlot const &slot, DynamicValueAttrs const &value_attrs) -> DynamicValueAttrs { + auto shard_value = [&](DynamicTensorSlot const &slot, DynamicValueAttrs const &value_attrs) + -> DynamicValueAttrs + { DynamicValueAttrsShardingInfo sharding_info = invocation_sharding_info.value_sharding.at(slot); return apply_dynamic_value_attrs_sharding_info(value_attrs, sharding_info); }; DynamicNodeInvocation result = DynamicNodeInvocation{ /*inputs=*/map_values2(invocation.inputs, shard_value), - /*node_attrs=*/apply_dynamic_node_attrs_sharding_info(invocation.node_attrs, invocation_sharding_info.device_coord), + /*node_attrs=*/apply_dynamic_node_attrs_sharding_info( + invocation.node_attrs, invocation_sharding_info.device_coords), /*outputs=*/map_values2(invocation.outputs, shard_value), }; @@ -303,11 +428,14 @@ std::unordered_set { require_invocation_is_ready_for_shard_expansion(i); - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_copy()) { + if (i.node_attrs.op_attrs.value().is_copy()) { return unordered_set_of(generate_shard_expansion_for_copy(i)); } + if (training_op_attrs_has_op_type(i.node_attrs.op_attrs.value(), OperatorType::REPLICATE)) { + return unordered_set_of(generate_shard_expansion_for_fwd_replicate(i)); + } + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); std::unordered_set shard_machine_coords = diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc index 58a32db6c1..51a79cff59 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -79,7 +79,7 @@ static DynamicNodeInvocation get_update_invocation_for_invocation( /*inputs=*/map_from_pairs( transform(tensor_roles, create_binding_for_role)), /*node_attrs=*/update_node_attrs, - /*outputs=*/std::unordered_map{}, + /*outputs=*/std::map{}, }; } diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc index 31de844555..05324c2195 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc @@ -28,6 +28,7 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index 49b8d4a77a..96523b6c31 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -58,33 +58,40 @@ TEST_SUITE(FF_TEST_SUITE) { }; DynamicNodeInvocation invocation_1 = DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ - {DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::INPUT, - /*slot_tensor_role=*/std::nullopt, - }, - value_1}, + /*inputs=*/std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + value_1, + }, }, /*node_attrs=*/node_attrs, /*outputs=*/ - std::unordered_map{ - {DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::OUTPUT, - /*slot_tensor_role=*/std::nullopt, - }, - value_2}, + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + value_2, + }, }, }; DynamicNodeInvocation invocation_2 = DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{}, + /*inputs=*/std::map{}, /*node_attrs=*/node_attrs, /*outputs=*/ - std::unordered_map{ + std::map{ { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::OUTPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_3, }, @@ -92,11 +99,12 @@ TEST_SUITE(FF_TEST_SUITE) { }; DynamicNodeInvocation invocation_3 = DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ + /*inputs=*/std::map{ { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::INPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_1, }, @@ -104,6 +112,7 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::WEIGHT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_2, }, @@ -111,12 +120,13 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::BIAS, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_1, }, }, /*node_attrs=*/node_attrs, - /*outputs=*/std::unordered_map{}, + /*outputs=*/std::map{}, }; std::unordered_set invocation_set = { diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc index 40b3460ee5..789d89a676 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -58,6 +58,7 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }; }; @@ -112,7 +113,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/mc2, + /*device_coords=*/nonempty_set{mc2}, /*mapping=*/std::nullopt, /*op_attrs=*/std::nullopt, /*layer_guid=*/ @@ -139,7 +140,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/mc1, + /*device_coord=*/nonempty_set{mc1}, /*mapping=*/std::nullopt, /*op_attrs=*/std::nullopt, /*layer_guid=*/ @@ -173,7 +174,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/mc2, + /*device_coord=*/nonempty_set{mc2}, /*mapping=*/std::nullopt, /*op_attrs=*/std::nullopt, /*layer_guid=*/ diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 9f8aeee726..2e14c88654 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Replicate") { MachineSpaceCoordinate gpu0 = MachineSpaceCoordinate{0_n, 0_n, DeviceType::GPU}; MachineSpaceCoordinate gpu1 = MachineSpaceCoordinate{0_n, 1_n, DeviceType::GPU}; - + ParallelTensorSpaceCoordinate tensor_coord0 = ParallelTensorSpaceCoordinate{ /*sum_component=*/0_n, /*discard_copy_component=*/0_n, @@ -78,7 +78,7 @@ TEST_SUITE(FF_TEST_SUITE) { KwargDataflowOutput{ Node{0}, TensorSlotName::OUTPUT, - }, + }, }; MappedParallelLayerInvocationInfo input = MappedParallelLayerInvocationInfo{ @@ -124,6 +124,7 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ TensorSlotName::INPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, @@ -148,6 +149,7 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ TensorSlotName::OUTPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, @@ -165,7 +167,7 @@ TEST_SUITE(FF_TEST_SUITE) { } // SUBCASE("standard op") { - // + // // } } diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index bf88d5ec38..90fbdec5f7 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -32,6 +32,7 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/role, + /*task_shard=*/std::nullopt, }; }; @@ -136,6 +137,7 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/role, + /*task_shard=*/std::nullopt, }; }; @@ -352,37 +354,42 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set invocation_set = { DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{}, + /*inputs=*/std::map{}, /*node_attrs=*/n1, /*outputs=*/ - std::unordered_map{ + std::map{ { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::OUTPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, v1, }, }, }, DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ - {DynamicTensorSlot{ + /*inputs=*/std::map{ + { + DynamicTensorSlot{ /*slot_name=*/TensorSlotName::INPUT, /*slot_tensor_role=*/std::nullopt, - }, - v1}, + /*task_shard=*/std::nullopt, + }, + v1, + }, }, /*node_attrs=*/n2, /*outputs=*/ - std::unordered_map{ - {DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::OUTPUT, - /*slot_tensor_role=*/std::nullopt, - }, - v2}, + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + v2, + }, }, }, }; @@ -413,48 +420,51 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set invocation_set = { DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{}, + /*inputs=*/std::map{}, /*node_attrs=*/n1_fwd, /*outputs=*/ - std::unordered_map{ + std::map{ std::pair{ DynamicTensorSlot{ /*slot_name=*/TensorSlotName::OUTPUT, /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v1_activation, }, }, }, DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ + /*inputs=*/std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::INPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v1_activation, }, }, /*node_attrs=*/n2_fwd, /*outputs=*/ - std::unordered_map{ + std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::OUTPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v2_activation, }, }, }, DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ + /*inputs=*/std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::INPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v1_activation, }, @@ -462,6 +472,7 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ TensorSlotName::OUTPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v2_activation, }, @@ -469,17 +480,19 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ TensorSlotName::OUTPUT, mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, }, v2_gradient, }, }, /*node_attrs=*/n2_bwd, /*outputs=*/ - std::unordered_map{ + std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::INPUT, mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, }, v1_gradient, }, diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index 19c21f5f89..9204e60c58 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -9,6 +9,8 @@ #include "op-attrs/ops/element_unary.h" #include "utils/one_to_many/one_to_many_filter_keys.h" #include "utils/one_to_many/one_to_many_filter_values.h" +#include "utils/containers/map_from_pairs.h" +#include "utils/containers/binary_merge_disjoint_maps.h" using namespace ::FlexFlow; @@ -36,10 +38,12 @@ static ParallelTensorSpaceCoordinate mk_pt_coord(nonnegative_int idx1, }; }; -DynamicTensorSlot mk_slot(TensorSlotName const &slot_name) { +DynamicTensorSlot mk_slot(TensorSlotName const &slot_name, + std::optional const &task_shard = std::nullopt) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/task_shard, }; }; @@ -245,7 +249,7 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorSpaceCoordinate const &output_2_shard_coord) -> DynamicNodeInvocationShardingInfo { return DynamicNodeInvocationShardingInfo{ - /*device_coord=*/device_coord, + /*device_coord=*/nonempty_set{device_coord}, /*value_sharding=*/{ mk_sharding_info(TensorSlotName::INPUT, input_shard_coord, mapped_task_group, device_coord), mk_sharding_info(TensorSlotName::WEIGHT, weight_shard_coord, mapped_task_group, device_coord), @@ -329,7 +333,7 @@ TEST_SUITE(FF_TEST_SUITE) { -> DynamicNodeInvocationShardingInfo { return DynamicNodeInvocationShardingInfo{ - /*device_coord=*/device_coord, + /*device_coord=*/nonempty_set{device_coord}, /*value_sharding=*/std::map{ { mk_slot(TensorSlotName::INPUT), @@ -360,6 +364,169 @@ TEST_SUITE(FF_TEST_SUITE) { mk_invocation_shard(mc2, pt2), }; + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("replicate operator") { + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + + ParallelTensorSpaceCoordinate pt1 = mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate pt2 = mk_pt_coord(0_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate pt3 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate pt4 = mk_pt_coord(0_n, 1_n, 0_n, 1_n); + + OneToMany src_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; + + OneToMany dst_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + {pt3, {mc3}}, + {pt4, {mc4}}, + }; + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::OUTPUT, + c2, + }, + }, + }; + }; + + MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(pt1, pt1), + }, + { + mc2, + mk_shard_binding(pt1, pt2), + }, + { + mc3, + mk_shard_binding(pt2, pt3), + }, + { + mc4, + mk_shard_binding(pt2, pt4), + }, + }, + }; + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coords=*/std::nullopt, + /*mapping=*/mapped_task_group, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, + }, + }, + /*layer_guid=*/dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), + }, + }, + }; + + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + + auto mk_output_binding = [&](MachineSpaceCoordinate const &mc) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/mc, + }, + DynamicValueAttrsShardingInfo{ + dst_binding.at_r(mc), + one_to_many_filter_keys(dst_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == dst_binding.at_r(mc); + }), + }, + }; + }; + + auto mk_invocation_shard = + [&](nonempty_set const &device_coords, + ParallelTensorSpaceCoordinate const &input_shard_coord, + std::unordered_set const &output_task_shards) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/device_coords, + /*value_sharding=*/ + binary_merge_disjoint_maps( + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrsShardingInfo{ + input_shard_coord, + one_to_many_filter_keys( + src_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == input_shard_coord; + }), + }, + }, + }, + map_from_pairs(transform(output_task_shards, mk_output_binding))), + }; + }; + + std::unordered_set correct = { + mk_invocation_shard(nonempty_set{mc1, mc2}, pt1, {mc1, mc2}), + mk_invocation_shard(nonempty_set{mc3, mc4}, pt2, {mc3, mc4}), + }; + nlohmann::json result_json = result; nlohmann::json correct_json = correct; diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h index 824fe77b39..85ff3d75fa 100644 --- a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h @@ -3,18 +3,20 @@ #include "utils/containers/binary_merge_maps_with.h" #include +#include "utils/containers/keys.h" +#include "utils/containers/intersection.h" namespace FlexFlow { template -std::unordered_map - binary_merge_disjoint_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs) { +std::map + binary_merge_disjoint_maps(std::map const &lhs, + std::map const &rhs) { - std::unordered_set lhs_keys = unordered_keys(lhs); - std::unordered_set rhs_keys = unordered_keys(rhs); + std::set lhs_keys = keys(lhs); + std::set rhs_keys = keys(rhs); - std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); + std::set shared_keys = intersection(lhs_keys, rhs_keys); ASSERT(shared_keys.empty()); return binary_merge_maps_with( diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h new file mode 100644 index 0000000000..536d402be0 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_DISJOINT_UNORDERED_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_DISJOINT_UNORDERED_MAPS_H + +#include +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "utils/containers/unordered_keys.h" +#include "utils/containers/intersection.h" + +namespace FlexFlow { + +template +std::unordered_map + binary_merge_disjoint_unordered_maps(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + + std::unordered_set lhs_keys = unordered_keys(lhs); + std::unordered_set rhs_keys = unordered_keys(rhs); + + std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); + ASSERT(shared_keys.empty()); + + return binary_merge_unordered_maps_with( + lhs, rhs, [](V const &, V const &) -> V { PANIC(); }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with.h b/lib/utils/include/utils/containers/binary_merge_maps_with.h index 2d0b57eb81..3c3b556830 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_MAPS_WITH_H -#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/generate_map.h" #include "utils/containers/intersection.h" -#include "utils/containers/unordered_keys.h" +#include "utils/containers/keys.h" #include "utils/containers/merge_maps_with_right_dominating.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" @@ -12,22 +12,22 @@ namespace FlexFlow { template -std::unordered_map - binary_merge_maps_with(std::unordered_map const &lhs, - std::unordered_map const &rhs, +std::map + binary_merge_maps_with(std::map const &lhs, + std::map const &rhs, F &&f) { - std::unordered_set l_keys = unordered_keys(lhs); - std::unordered_set r_keys = unordered_keys(rhs); + std::set l_keys = keys(lhs); + std::set r_keys = keys(rhs); - std::unordered_set l_only_keys = set_minus(l_keys, r_keys); - std::unordered_set r_only_keys = set_minus(r_keys, l_keys); - std::unordered_set both_keys = intersection(r_keys, l_keys); + std::set l_only_keys = set_minus(l_keys, r_keys); + std::set r_only_keys = set_minus(r_keys, l_keys); + std::set both_keys = intersection(r_keys, l_keys); - std::unordered_map l_only = restrict_keys(lhs, l_only_keys); - std::unordered_map r_only = restrict_keys(rhs, r_only_keys); + std::map l_only = restrict_keys(lhs, l_only_keys); + std::map r_only = restrict_keys(rhs, r_only_keys); - std::unordered_map merged = generate_unordered_map( + std::map merged = generate_map( both_keys, [&](K const &k) { return f(lhs.at(k), rhs.at(k)); }); return merge_maps_with_right_dominating(std::vector{ diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h b/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h index f6e23af11c..25be62d9c3 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h @@ -6,9 +6,9 @@ namespace FlexFlow { template -std::unordered_map binary_merge_maps_with_left_dominating( - std::unordered_map const &lhs, std::unordered_map const &rhs) { - std::unordered_map result; +std::map binary_merge_maps_with_left_dominating( + std::map const &lhs, std::map const &rhs) { + std::map result; merge_in_map(rhs, result); merge_in_map(lhs, result); return result; diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h b/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h index e5e29dfcb9..e4bfdd6d29 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h @@ -6,9 +6,9 @@ namespace FlexFlow { template -std::unordered_map binary_merge_maps_with_right_dominating( - std::unordered_map const &lhs, std::unordered_map const &rhs) { - std::unordered_map result; +std::map binary_merge_maps_with_right_dominating( + std::map const &lhs, std::map const &rhs) { + std::map result; merge_in_map(lhs, result); merge_in_map(rhs, result); return result; diff --git a/lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h new file mode 100644 index 0000000000..2f8be45802 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_H + +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/intersection.h" +#include "utils/containers/unordered_keys.h" +#include "utils/containers/merge_unordered_maps_with_right_dominating.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include + +namespace FlexFlow { + +template +std::unordered_map + binary_merge_unordered_maps_with(std::unordered_map const &lhs, + std::unordered_map const &rhs, + F &&f) { + + std::unordered_set l_keys = unordered_keys(lhs); + std::unordered_set r_keys = unordered_keys(rhs); + + std::unordered_set l_only_keys = set_minus(l_keys, r_keys); + std::unordered_set r_only_keys = set_minus(r_keys, l_keys); + std::unordered_set both_keys = intersection(r_keys, l_keys); + + std::unordered_map l_only = restrict_keys(lhs, l_only_keys); + std::unordered_map r_only = restrict_keys(rhs, r_only_keys); + + std::unordered_map merged = generate_unordered_map( + both_keys, [&](K const &k) { return f(lhs.at(k), rhs.at(k)); }); + + return merge_unordered_maps_with_right_dominating(std::vector{ + l_only, + r_only, + merged, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h new file mode 100644 index 0000000000..0d71a1b7cc --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_LEFT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_LEFT_DOMINATING_H + +#include "utils/containers/merge_in_unordered_map.h" + +namespace FlexFlow { + +template +std::unordered_map binary_merge_unordered_maps_with_left_dominating( + std::unordered_map const &lhs, std::unordered_map const &rhs) { + std::unordered_map result; + merge_in_unordered_map(rhs, result); + merge_in_unordered_map(lhs, result); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h new file mode 100644 index 0000000000..6a3581a635 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H + +#include "utils/containers/merge_in_unordered_map.h" + +namespace FlexFlow { + +template +std::unordered_map binary_merge_unordered_maps_with_right_dominating( + std::unordered_map const &lhs, std::unordered_map const &rhs) { + std::unordered_map result; + merge_in_unordered_map(lhs, result); + merge_in_unordered_map(rhs, result); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 70de6b5020..e0440cc791 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FLATMAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FLATMAP_H -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" #include #include #include +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -77,7 +77,7 @@ std::unordered_map flatmap(std::unordered_map const &m, std::unordered_map result; for (auto const &[k, v] : m) { - result = binary_merge_disjoint_maps(result, f(k, v)); + result = binary_merge_disjoint_unordered_maps(result, f(k, v)); } return result; diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index ed44b26c36..d5e2faf67a 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -2,17 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONLY_H #include "utils/containers/maybe_get_only.h" -#include "utils/exception.h" +#include #include "utils/optional.h" namespace FlexFlow { template typename C::value_type get_only(C const &c) { - return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error(fmt::format( - "Encountered container with size {} in get_only", c.size())); - }); + ASSERT(c.size() == 1); + return maybe_get_only(c).value(); } template diff --git a/lib/utils/include/utils/containers/map_from_pairs.h b/lib/utils/include/utils/containers/map_from_pairs.h index 7c470d4d3e..f5f2dad415 100644 --- a/lib/utils/include/utils/containers/map_from_pairs.h +++ b/lib/utils/include/utils/containers/map_from_pairs.h @@ -1,18 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_PAIRS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_PAIRS_H -#include -#include +#include namespace FlexFlow { -template -std::unordered_map - map_from_pairs(std::unordered_set> const &pairs) { - - std::unordered_map result(pairs.cbegin(), pairs.cend()); - - return result; +template +std::map map_from_pairs(C const &c) { + return std::map(c.cbegin(), c.cend()); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index 5cd44d8a5d..ff41248a30 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -2,11 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H #include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" #include "utils/exception.h" #include #include +#include namespace FlexFlow { @@ -19,14 +21,35 @@ template > std::unordered_map map_keys(std::unordered_map const &m, - F const &f) { + F &&f) { std::unordered_map result; for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } - ASSERT(keys(m).size() == keys(result).size(), + ASSERT(m.size() == result.size(), + "keys passed to map_keys must be transformed into distinct keys"); + + return result; +} + +/** + * @brief Applies the given function to all the keys within the given map and + * returns the updated map. + */ +template > +std::map map_keys(std::map const &m, F &&f) { + + std::map result; + for (auto const &kv : m) { + result.insert({f(kv.first), kv.second}); + } + + ASSERT(m.size() == result.size(), "keys passed to map_keys must be transformed into distinct keys"); return result; diff --git a/lib/utils/include/utils/containers/map_values2.h b/lib/utils/include/utils/containers/map_values2.h index 752a8babd3..dd943b02bb 100644 --- a/lib/utils/include/utils/containers/map_values2.h +++ b/lib/utils/include/utils/containers/map_values2.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -19,6 +20,20 @@ std::unordered_map map_values2(std::unordered_map const &m, return result; } +template > +std::map map_values2(std::map const &m, + F &&f) { + std::map result; + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.first, kv.second)}); + } + return result; +} + + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/merge_disjoint_maps.h b/lib/utils/include/utils/containers/merge_disjoint_maps.h index eccb06180a..b541fdbd53 100644 --- a/lib/utils/include/utils/containers/merge_disjoint_maps.h +++ b/lib/utils/include/utils/containers/merge_disjoint_maps.h @@ -9,12 +9,12 @@ namespace FlexFlow { template -std::unordered_map merge_disjoint_maps(C const &c) { - std::unordered_map empty = {}; +std::map merge_disjoint_maps(C const &c) { + std::map empty = {}; return foldl(c, /*init=*/empty, - [](std::unordered_map const &lhs, - std::unordered_map const &rhs) { + [](std::map const &lhs, + std::map const &rhs) { return binary_merge_disjoint_maps(lhs, rhs); }); } diff --git a/lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h b/lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h new file mode 100644 index 0000000000..1bd7fb019b --- /dev/null +++ b/lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_DISJOINT_UNORDERED_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_DISJOINT_UNORDERED_MAPS_H + +#include "utils/containers/foldl.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" + +namespace FlexFlow { + +template +std::unordered_map merge_disjoint_unordered_maps(C const &c) { + std::unordered_map empty = {}; + return foldl(c, + /*init=*/empty, + [](std::unordered_map const &lhs, + std::unordered_map const &rhs) { + return binary_merge_disjoint_unordered_maps(lhs, rhs); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_in_map.h b/lib/utils/include/utils/containers/merge_in_map.h index edae4b8a6a..e41c1a6826 100644 --- a/lib/utils/include/utils/containers/merge_in_map.h +++ b/lib/utils/include/utils/containers/merge_in_map.h @@ -1,13 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_MAP_H -#include +#include namespace FlexFlow { template -void merge_in_map(std::unordered_map const &m, - std::unordered_map &result) { +void merge_in_map(std::map const &m, + std::map &result) { for (auto const &[k, v] : m) { auto it = result.find(k); if (it != result.end()) { diff --git a/lib/utils/include/utils/containers/merge_in_unordered_map.h b/lib/utils/include/utils/containers/merge_in_unordered_map.h new file mode 100644 index 0000000000..7c2b31b8fc --- /dev/null +++ b/lib/utils/include/utils/containers/merge_in_unordered_map.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_UNORDERED_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_UNORDERED_MAPS_H + +#include + +namespace FlexFlow { + +template +void merge_in_unordered_map(std::unordered_map const &m, + std::unordered_map &result) { + for (auto const &[k, v] : m) { + auto it = result.find(k); + if (it != result.end()) { + it->second = v; + } else { + result.insert({k, v}); + } + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_maps_with.h b/lib/utils/include/utils/containers/merge_maps_with.h index 2f5a09e26e..eef6c9af67 100644 --- a/lib/utils/include/utils/containers/merge_maps_with.h +++ b/lib/utils/include/utils/containers/merge_maps_with.h @@ -9,13 +9,13 @@ namespace FlexFlow { template -std::unordered_map - merge_maps_with(std::vector> const &to_merge, +std::map + merge_maps_with(std::vector> const &to_merge, F &&f) { return foldl(to_merge, - std::unordered_map{}, - [&](std::unordered_map const &accum, - std::unordered_map const &m) { + std::map{}, + [&](std::map const &accum, + std::map const &m) { return binary_merge_maps_with(accum, m, f); }); } diff --git a/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h b/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h index 1d4f8536d8..6271cef2c0 100644 --- a/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h +++ b/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h @@ -8,10 +8,10 @@ namespace FlexFlow { template -std::unordered_map merge_maps_with_right_dominating(C const &c) { - std::unordered_map result; +std::map merge_maps_with_right_dominating(C const &c) { + std::map result; - for (std::unordered_map const &m : c) { + for (std::map const &m : c) { merge_in_map(m, result); } diff --git a/lib/utils/include/utils/containers/merge_unordered_maps_with.h b/lib/utils/include/utils/containers/merge_unordered_maps_with.h new file mode 100644 index 0000000000..fee7fa2fa4 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_unordered_maps_with.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_H + +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "utils/containers/foldl.h" +#include +#include + +namespace FlexFlow { + +template +std::unordered_map + merge_unordered_maps_with(std::vector> const &to_merge, + F &&f) { + return foldl(to_merge, + std::unordered_map{}, + [&](std::unordered_map const &accum, + std::unordered_map const &m) { + return binary_merge_unordered_maps_with(accum, m, f); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h b/lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h new file mode 100644 index 0000000000..1323378019 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H + +#include "utils/containers/merge_in_unordered_map.h" + +namespace FlexFlow { + +template +std::unordered_map merge_unordered_maps_with_right_dominating(C const &c) { + std::unordered_map result; + + for (std::unordered_map const &m : c) { + merge_in_unordered_map(m, result); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/restrict_keys.h b/lib/utils/include/utils/containers/restrict_keys.h index bedcc4ed8e..353b2ec237 100644 --- a/lib/utils/include/utils/containers/restrict_keys.h +++ b/lib/utils/include/utils/containers/restrict_keys.h @@ -4,6 +4,8 @@ #include "utils/containers/contains.h" #include #include +#include +#include namespace FlexFlow { @@ -19,6 +21,17 @@ std::unordered_map restrict_keys(std::unordered_map const &m, return result; } +template +std::map restrict_keys(std::map const &m, + std::set const &mask) { + std::map result; + for (auto const &kv : m) { + if (contains(mask, kv.first)) { + result.insert(kv); + } + } + return result; +} } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/zip_values_strict.h b/lib/utils/include/utils/containers/zip_values_strict.h index 1a3ce95eb1..b891490e39 100644 --- a/lib/utils/include/utils/containers/zip_values_strict.h +++ b/lib/utils/include/utils/containers/zip_values_strict.h @@ -6,6 +6,9 @@ #include "utils/containers/require_same.h" #include #include +#include +#include "utils/containers/keys.h" +#include "utils/containers/generate_map.h" namespace FlexFlow { @@ -24,6 +27,21 @@ std::unordered_map> }); } +template +std::map> + zip_values_strict(std::map const &m1, + std::map const &m2) { + + ASSERT(keys(m1) == keys(m2)); + + return generate_map(require_same(keys(m1), keys(m2)), [&](K const &k) { + return std::pair{ + m1.at(k), + m2.at(k), + }; + }); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h b/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h index fd77509e4d..e3947605c0 100644 --- a/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h +++ b/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_PATH_TO_LEAF_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_PATH_TO_LEAF_MAP_H -#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" #include "utils/containers/map_keys.h" #include "utils/containers/multiset_union.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" @@ -30,7 +30,7 @@ std::unordered_map get_path_to_leaf_map( get_path_to_leaf_map(impl.get_right_child(parent), impl), [](BinaryTreePath const &p) { return nest_inside_right_child(p); }); - return binary_merge_disjoint_maps(left_map, right_map); + return binary_merge_disjoint_unordered_maps(left_map, right_map); }, [](Leaf const &leaf) -> std::unordered_map { return std::unordered_map{ diff --git a/lib/utils/include/utils/nonempty_set/nonempty_set.h b/lib/utils/include/utils/nonempty_set/nonempty_set.h index 93da743592..fe4b152bd5 100644 --- a/lib/utils/include/utils/nonempty_set/nonempty_set.h +++ b/lib/utils/include/utils/nonempty_set/nonempty_set.h @@ -8,6 +8,8 @@ #include "utils/fmt/set.h" #include "utils/positive_int/positive_int.h" #include "utils/containers/unordered_set_of.h" +#include "utils/json/check_is_json_deserializable.h" +#include "utils/json/check_is_json_serializable.h" namespace FlexFlow { @@ -76,21 +78,24 @@ struct nonempty_set { return unordered_set_of(this->raw); } + using const_iterator = typename std::set::const_iterator; using value_type = T; + using reference = value_type &; + using const_reference = value_type const &; - typename std::set::const_iterator begin() const { + const_iterator begin() const { return this->raw.cbegin(); } - typename std::set::const_iterator cbegin() const { + const_iterator cbegin() const { return this->raw.cbegin(); } - typename std::set::const_iterator end() const { + const_iterator end() const { return this->raw.cend(); } - typename std::set::const_iterator cend() const { + const_iterator cend() const { return this->raw.cend(); } @@ -122,6 +127,27 @@ std::ostream &operator<<(std::ostream &s, nonempty_set const &m) { } // namespace FlexFlow +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::nonempty_set> { + static ::FlexFlow::nonempty_set from_json(json const &j) { + CHECK_IS_JSON_DESERIALIZABLE(T); + + std::set s = j; + + return ::FlexFlow::nonempty_set{s}; + } + + static void to_json(json &j, ::FlexFlow::nonempty_set const &s) { + CHECK_IS_JSON_SERIALIZABLE(T); + + j = s.unwrap_as_set(); + } +}; + +} // namespace nlohmann + namespace std { template diff --git a/lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h b/lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h new file mode 100644 index 0000000000..9d31dc968c --- /dev/null +++ b/lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_REQUIRE_ONE_TO_MANY_IS_BIJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_REQUIRE_ONE_TO_MANY_IS_BIJECTION_H + +#include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/containers/map_values.h" +#include "utils/containers/get_only.h" +#include "utils/one_to_many/one_to_many.h" +#include "utils/nonempty_set/nonempty_set.h" + +namespace FlexFlow { + +template +bidict require_one_to_many_is_bijection(OneToMany const &otm) { + return bidict_from_map( + map_values(otm.l_to_r(), + [](nonempty_set const &s) -> R { + return get_only(s.unwrap_as_set()); + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/orthotope/minimal_dim_domain.h b/lib/utils/include/utils/orthotope/minimal_dim_domain.h index c9d1214278..f9bf5fa979 100644 --- a/lib/utils/include/utils/orthotope/minimal_dim_domain.h +++ b/lib/utils/include/utils/orthotope/minimal_dim_domain.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_DIM_DOMAIN_H #include "utils/containers/are_disjoint.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filtermap_values.h" #include "utils/containers/generate_unordered_map.h" #include "utils/containers/map_from_keys_and_values.h" @@ -16,6 +15,7 @@ #include "utils/orthotope/minimal_dim_domain.dtg.h" #include "utils/orthotope/minimal_orthotope.dtg.h" #include "utils/containers/unordered_keys.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -66,7 +66,7 @@ DimDomain dim_domain_from_minimal_dim_domain( ASSERT(are_disjoint(nontrivial_dims, trivial_dims)); return DimDomain{ - /*dims=*/binary_merge_disjoint_maps( + /*dims=*/binary_merge_disjoint_unordered_maps( map_values( minimal_dim_domain.dims, [](int_ge_two x) { return x.positive_int_from_int_ge_two(); }), diff --git a/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc b/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc index 0569b3ed0b..95d19083ac 100644 --- a/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc +++ b/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -template std::unordered_map - binary_merge_disjoint_maps(std::unordered_map const &, - std::unordered_map const &); +template std::map + binary_merge_disjoint_maps(std::map const &, + std::map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc b/lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..60ccf3a7e5 --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_map + binary_merge_disjoint_unordered_maps(std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_maps_with.cc b/lib/utils/src/utils/containers/binary_merge_maps_with.cc index 35f771f60c..4679d21227 100644 --- a/lib/utils/src/utils/containers/binary_merge_maps_with.cc +++ b/lib/utils/src/utils/containers/binary_merge_maps_with.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_maps_with.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; using F = std::function; -template std::unordered_map binary_merge_maps_with( - std::unordered_map const &, std::unordered_map const &, F &&); +template std::map binary_merge_maps_with( + std::map const &, std::map const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc b/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc index c459e82061..d5b4f6cebe 100644 --- a/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc +++ b/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_maps_with_left_dominating.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -template std::unordered_map - binary_merge_maps_with_left_dominating(std::unordered_map const &, - std::unordered_map const &); +template std::map + binary_merge_maps_with_left_dominating(std::map const &, + std::map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc index df934387d2..bbc799150e 100644 --- a/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc +++ b/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_maps_with_right_dominating.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -template std::unordered_map - binary_merge_maps_with_right_dominating(std::unordered_map const &, - std::unordered_map const &); +template std::map + binary_merge_maps_with_right_dominating(std::map const &, + std::map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc new file mode 100644 index 0000000000..de8c484d0c --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +template std::unordered_map binary_merge_unordered_maps_with( + std::unordered_map const &, std::unordered_map const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc new file mode 100644 index 0000000000..d777eb0e29 --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_unordered_maps_with_left_dominating.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_map + binary_merge_unordered_maps_with_left_dominating(std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc new file mode 100644 index 0000000000..f5586cec6b --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_unordered_maps_with_right_dominating.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template + std::unordered_map binary_merge_unordered_maps_with_right_dominating( + std::unordered_map const &, std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_from_pairs.cc b/lib/utils/src/utils/containers/map_from_pairs.cc index ba0eed8c15..8dc8ffa29c 100644 --- a/lib/utils/src/utils/containers/map_from_pairs.cc +++ b/lib/utils/src/utils/containers/map_from_pairs.cc @@ -1,12 +1,21 @@ #include "utils/containers/map_from_pairs.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" +#include +#include +#include namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; -template std::unordered_map +template std::map + map_from_pairs(std::set> const &); + +template std::map map_from_pairs(std::unordered_set> const &); +template std::map + map_from_pairs(std::vector> const &); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_keys.cc b/lib/utils/src/utils/containers/map_keys.cc index 7473c7e16d..daf9ed25d4 100644 --- a/lib/utils/src/utils/containers/map_keys.cc +++ b/lib/utils/src/utils/containers/map_keys.cc @@ -1 +1,22 @@ #include "utils/containers/map_keys.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using VT2 = value_type<2>; + +template + std::unordered_map map_keys(std::unordered_map const &, + std::function &&); + +using OV0 = ordered_value_type<0>; +using OV1 = ordered_value_type<1>; + +template + std::map map_keys(std::map const &m, + std::function &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_values2.cc b/lib/utils/src/utils/containers/map_values2.cc index 6aba8f4db0..8840f15aee 100644 --- a/lib/utils/src/utils/containers/map_values2.cc +++ b/lib/utils/src/utils/containers/map_values2.cc @@ -1,14 +1,21 @@ #include "utils/containers/map_values2.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; -using V2 = value_type<2>; -using F = std::function; +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using VT2 = value_type<2>; -template std::unordered_map map_values2(std::unordered_map const &, - F &&); +template std::unordered_map map_values2( + std::unordered_map const &, + std::function &&); + +using OT0 = ordered_value_type<0>; + +template std::map map_values2( + std::map const &, + std::function &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_disjoint_maps.cc b/lib/utils/src/utils/containers/merge_disjoint_maps.cc index dec8ee0618..e810b6b4f0 100644 --- a/lib/utils/src/utils/containers/merge_disjoint_maps.cc +++ b/lib/utils/src/utils/containers/merge_disjoint_maps.cc @@ -1,12 +1,13 @@ #include "utils/containers/merge_disjoint_maps.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -using C = std::vector>; +using C = std::vector>; -template std::unordered_map merge_disjoint_maps(C const &); +template std::map merge_disjoint_maps(C const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc b/lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..1ef6af0877 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc @@ -0,0 +1,12 @@ +#include "utils/containers/merge_disjoint_unordered_maps.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using C = std::vector>; + +template std::unordered_map merge_disjoint_unordered_maps(C const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_in_map.cc b/lib/utils/src/utils/containers/merge_in_map.cc index ada1a803ad..618128ff3b 100644 --- a/lib/utils/src/utils/containers/merge_in_map.cc +++ b/lib/utils/src/utils/containers/merge_in_map.cc @@ -1,12 +1,12 @@ #include "utils/containers/merge_in_map.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; -template void merge_in_map(std::unordered_map const &, - std::unordered_map &); +template void merge_in_map(std::map const &, + std::map &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_in_unordered_map.cc b/lib/utils/src/utils/containers/merge_in_unordered_map.cc new file mode 100644 index 0000000000..95228dc5f2 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_in_unordered_map.cc @@ -0,0 +1,12 @@ +#include "utils/containers/merge_in_unordered_map.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template void merge_in_unordered_map(std::unordered_map const &, + std::unordered_map &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_maps_with.cc b/lib/utils/src/utils/containers/merge_maps_with.cc index 0375b16bc4..b5b471a1f8 100644 --- a/lib/utils/src/utils/containers/merge_maps_with.cc +++ b/lib/utils/src/utils/containers/merge_maps_with.cc @@ -1,13 +1,14 @@ #include "utils/containers/merge_maps_with.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; using F = std::function; -template std::unordered_map - merge_maps_with(std::vector> const &, F &&); +template std::map + merge_maps_with(std::vector> const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc index d8c269d6e9..f33b46780c 100644 --- a/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc +++ b/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc @@ -1,12 +1,13 @@ #include "utils/containers/merge_maps_with_right_dominating.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -using C = std::vector>; +using C = std::vector>; -template std::unordered_map merge_maps_with_right_dominating(C const &); +template std::map merge_maps_with_right_dominating(C const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_unordered_maps_with.cc b/lib/utils/src/utils/containers/merge_unordered_maps_with.cc new file mode 100644 index 0000000000..60218312f3 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_unordered_maps_with.cc @@ -0,0 +1,14 @@ +#include "utils/containers/merge_unordered_maps_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +std::unordered_map + merge_unordered_maps_with(std::vector> const &, + F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc new file mode 100644 index 0000000000..1dd7da70a3 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc @@ -0,0 +1,12 @@ +#include "utils/containers/merge_unordered_maps_with_right_dominating.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using C = std::vector>; + +template std::unordered_map merge_unordered_maps_with_right_dominating(C const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/restrict_keys.cc b/lib/utils/src/utils/containers/restrict_keys.cc index d2749b7ea2..13584abec1 100644 --- a/lib/utils/src/utils/containers/restrict_keys.cc +++ b/lib/utils/src/utils/containers/restrict_keys.cc @@ -1 +1,21 @@ #include "utils/containers/restrict_keys.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using VT0 = value_type<0>; +using VT1 = value_type<1>; + +template + std::unordered_map restrict_keys(std::unordered_map const &, + std::unordered_set const &); + +using OV0 = ordered_value_type<0>; + +template + std::map restrict_keys(std::map const &, + std::set const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_values_strict.cc b/lib/utils/src/utils/containers/zip_values_strict.cc index b9bed29a1b..c1710129a5 100644 --- a/lib/utils/src/utils/containers/zip_values_strict.cc +++ b/lib/utils/src/utils/containers/zip_values_strict.cc @@ -1,14 +1,23 @@ #include "utils/containers/zip_values_strict.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; -using V1 = value_type<1>; -using V2 = value_type<2>; +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using VT2 = value_type<2>; + +template std::unordered_map> + zip_values_strict(std::unordered_map const &, + std::unordered_map const &); + +using OV0 = ordered_value_type<0>; + +template std::map> + zip_values_strict(std::map const &, + std::map const &); + -template std::unordered_map> - zip_values_strict(std::unordered_map const &, - std::unordered_map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/nonempty_set/nonempty_set.cc b/lib/utils/src/utils/nonempty_set/nonempty_set.cc index 1af2951f10..a7bd7f7c5a 100644 --- a/lib/utils/src/utils/nonempty_set/nonempty_set.cc +++ b/lib/utils/src/utils/nonempty_set/nonempty_set.cc @@ -1,7 +1,9 @@ #include "utils/nonempty_set/nonempty_set.h" #include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/jsonable_ordered_value_type.h" using T = ::FlexFlow::ordered_value_type<0>; +using J = ::FlexFlow::jsonable_ordered_value_type<0>; namespace FlexFlow { @@ -16,6 +18,12 @@ template std::ostream &operator<<(std::ostream &, nonempty_set const &); } // namespace FlexFlow +namespace nlohmann { + +template struct adl_serializer<::FlexFlow::nonempty_set>; + +} // namespace nlohmann + namespace std { template struct hash<::FlexFlow::nonempty_set>; diff --git a/lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc b/lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc new file mode 100644 index 0000000000..60a78bb123 --- /dev/null +++ b/lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/require_one_to_many_is_bijection.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; + +template + bidict require_one_to_many_is_bijection(OneToMany const &); + +} // namespace FlexFlow From 9bea615de80e2ce0d1fb7f5b9c397b274ad36c5a Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 29 May 2026 03:19:00 -0700 Subject: [PATCH 18/19] Pass shard expansion test for bwd replicate --- .../dynamic_graph/shard_expansion.cc | 129 +++++++- .../dynamic_graph/shard_expansion.cc | 312 ++++++++++++------ lib/utils/include/utils/bidict/bidict.h | 3 + 3 files changed, 343 insertions(+), 101 deletions(-) diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index db9484e361..f910c6c6c2 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -226,6 +226,10 @@ static std::set }); } +// TODO(@lockshaw): There is a lot of code duplication between +// generate_shard_expansion_for_fwd_replicate and +// generate_shard_expansion_for_bwd_replicate that should eventually be +// factored out. static std::set generate_shard_expansion_for_fwd_replicate(DynamicNodeInvocation const &i) { ASSERT(i.node_attrs.task_type == DynamicTaskType::FWD); @@ -341,6 +345,121 @@ static std::set return transform(input_tensor_shards, invocation_sharding_info_for_input_tensor_shard); } +static std::set + generate_shard_expansion_for_bwd_replicate(DynamicNodeInvocation const &i) { + ASSERT(i.node_attrs.task_type == DynamicTaskType::BWD); + + MappedOperatorTaskGroup node_mapping = assert_unwrap(i.node_attrs.mapping); + + DynamicTensorSlot expected_output_grad_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs output_grad = require_only_key(i.inputs, expected_output_grad_slot); + + DynamicTensorSlot expected_input_grad_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs input_grad = require_only_key(i.outputs, expected_input_grad_slot); + + bidict + output_grad_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(output_grad.mapping)); + + bidict + input_grad_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(input_grad.mapping)); + + std::set input_grad_tensor_shards = set_of(input_grad_value_mapping.left_values()); + + auto get_task_shard_machine_coords_for_input_grad_tensor_shard + = [&](ParallelTensorSpaceCoordinate const &input_grad_tensor_shard) + -> nonempty_set + { + bidict produce_input_grad_tensor_shard + = bidict_filter_values( + node_mapping.get_shard_bindings(), + [&](OperatorAtomicTaskShardBinding const &b) -> bool { + return ptensor_space_coord_for_slot_name(b, TensorSlotName::INPUT) == input_grad_tensor_shard; + }); + + return nonempty_set(set_of(produce_input_grad_tensor_shard.left_values())); + }; + + auto invocation_sharding_info_for_input_grad_tensor_shard = [&](ParallelTensorSpaceCoordinate const &c) + -> DynamicNodeInvocationShardingInfo + { + nonempty_set task_shard_machine_coords = + get_task_shard_machine_coords_for_input_grad_tensor_shard(c); + + std::map output_grad_sharding_infos = + generate_map(task_shard_machine_coords.unwrap_as_set(), + [&](MachineSpaceCoordinate const &mc) + -> DynamicValueAttrsShardingInfo + { + ParallelTensorSpaceCoordinate pc = output_grad_value_mapping.at_r(mc); + + return DynamicValueAttrsShardingInfo{ + /*shard_coord=*/pc, + /*mapping=*/OneToMany{ + { + pc, + {mc}, + }, + }, + }; + }); + + std::map keyed_output_grad_sharding_infos = + map_keys(output_grad_sharding_infos, + [&](MachineSpaceCoordinate const &mc) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/mc, + }; + }); + + DynamicTensorSlot input_grad_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrsShardingInfo input_grad_sharding_info = DynamicValueAttrsShardingInfo{ + /*shard_coord=*/c, + /*mapping=*/OneToMany{ + { + c, + {input_grad_value_mapping.at_l(c)}, + }, + }, + }; + + std::map sharding_infos = + binary_merge_disjoint_maps( + keyed_output_grad_sharding_infos, + std::map{ + { + input_grad_slot, + input_grad_sharding_info, + }, + }); + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/task_shard_machine_coords, + /*value_sharding=*/sharding_infos, + }; + }; + + return transform(input_grad_tensor_shards, invocation_sharding_info_for_input_grad_tensor_shard); +} + std::unordered_set perform_shard_expansion_for_invocation(DynamicNodeInvocation const &i) { @@ -433,7 +552,15 @@ std::unordered_set } if (training_op_attrs_has_op_type(i.node_attrs.op_attrs.value(), OperatorType::REPLICATE)) { - return unordered_set_of(generate_shard_expansion_for_fwd_replicate(i)); + DynamicTaskType task_type = assert_unwrap(i.node_attrs.task_type); + switch (task_type) { + case DynamicTaskType::FWD: + return unordered_set_of(generate_shard_expansion_for_fwd_replicate(i)); + case DynamicTaskType::BWD: + return unordered_set_of(generate_shard_expansion_for_bwd_replicate(i)); + default: + PANIC("Unexpected task type for Replicate: {}", task_type); + } } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index 9204e60c58..7f23053943 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -379,18 +379,6 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorSpaceCoordinate pt3 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); ParallelTensorSpaceCoordinate pt4 = mk_pt_coord(0_n, 1_n, 0_n, 1_n); - OneToMany src_binding{ - {pt1, {mc1}}, - {pt2, {mc2}}, - }; - - OneToMany dst_binding{ - {pt1, {mc1}}, - {pt2, {mc2}}, - {pt3, {mc3}}, - {pt4, {mc4}}, - }; - auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, ParallelTensorSpaceCoordinate const &c2) -> OperatorAtomicTaskShardBinding { @@ -429,110 +417,234 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; - DynamicNodeInvocation input = DynamicNodeInvocation{ - /*inputs=*/{ - { - DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::INPUT, - /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), - /*task_shard=*/std::nullopt, - }, - mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), - }, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/DynamicTaskType::FWD, - /*device_coords=*/std::nullopt, - /*mapping=*/mapped_task_group, - /*op_attrs=*/TrainingOperationAttrs{ - PCGOperatorAttrs{ - ReplicateAttrs{ - /*replicate_degree=*/2_p, + SUBCASE("fwd") { + OneToMany src_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; + + OneToMany dst_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + {pt3, {mc3}}, + {pt4, {mc4}}, + }; + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coords=*/std::nullopt, + /*mapping=*/mapped_task_group, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, }, - }, - /*layer_guid=*/dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - { - DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::OUTPUT, - /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), - /*task_shard=*/std::nullopt, + /*layer_guid=*/dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), + }, + }, + }; + + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + + auto mk_output_binding = [&](MachineSpaceCoordinate const &mc) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/mc, + }, + DynamicValueAttrsShardingInfo{ + dst_binding.at_r(mc), + one_to_many_filter_keys(dst_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == dst_binding.at_r(mc); + }), + }, + }; + }; + + auto mk_invocation_shard = + [&](nonempty_set const &device_coords, + ParallelTensorSpaceCoordinate const &input_shard_coord, + std::unordered_set const &output_task_shards) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/device_coords, + /*value_sharding=*/ + binary_merge_disjoint_maps( + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrsShardingInfo{ + input_shard_coord, + one_to_many_filter_keys( + src_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == input_shard_coord; + }), + }, }, - mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), - }, - }, - }; + }, + map_from_pairs(transform(output_task_shards, mk_output_binding))), + }; + }; - std::unordered_set result = - generate_shard_expansion_for_invocation(input); + std::unordered_set correct = { + mk_invocation_shard(nonempty_set{mc1, mc2}, pt1, {mc1, mc2}), + mk_invocation_shard(nonempty_set{mc3, mc4}, pt2, {mc3, mc4}), + }; + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } - auto mk_output_binding = [&](MachineSpaceCoordinate const &mc) - -> std::pair - { - return { - DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::OUTPUT, - /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), - /*task_shard=*/mc, - }, - DynamicValueAttrsShardingInfo{ - dst_binding.at_r(mc), - one_to_many_filter_keys(dst_binding, - [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { - return pt_coord == dst_binding.at_r(mc); - }), - }, + SUBCASE("bwd") { + OneToMany output_grad_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + {pt3, {mc3}}, + {pt4, {mc4}}, }; - }; - auto mk_invocation_shard = - [&](nonempty_set const &device_coords, - ParallelTensorSpaceCoordinate const &input_shard_coord, - std::unordered_set const &output_task_shards) - -> DynamicNodeInvocationShardingInfo { + OneToMany input_grad_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; - return DynamicNodeInvocationShardingInfo{ - /*device_coords=*/device_coords, - /*value_sharding=*/ - binary_merge_disjoint_maps( - std::map{ + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ { - DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::INPUT, - /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), - /*task_shard=*/std::nullopt, - }, - DynamicValueAttrsShardingInfo{ - input_shard_coord, - one_to_many_filter_keys( - src_binding, - [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { - return pt_coord == input_shard_coord; - }), + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(0, TensorSlotName::OUTPUT, output_grad_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::BWD, + /*device_coords=*/std::nullopt, + /*mapping=*/mapped_task_group, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, }, - }, - map_from_pairs(transform(output_task_shards, mk_output_binding))), + /*layer_guid=*/dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(20, TensorSlotName::INPUT, input_grad_binding, std::nullopt), + }, + }, }; - }; - std::unordered_set correct = { - mk_invocation_shard(nonempty_set{mc1, mc2}, pt1, {mc1, mc2}), - mk_invocation_shard(nonempty_set{mc3, mc4}, pt2, {mc3, mc4}), - }; + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + auto mk_output_grad_binding = [&](MachineSpaceCoordinate const &mc) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/mc, + }, + DynamicValueAttrsShardingInfo{ + output_grad_binding.at_r(mc), + one_to_many_filter_keys(output_grad_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == output_grad_binding.at_r(mc); + }), + }, + }; + }; - nlohmann::json result_json = result; - nlohmann::json correct_json = correct; + auto mk_invocation_shard = + [&](nonempty_set const &device_coords, + std::unordered_set const &output_grad_task_shards, + ParallelTensorSpaceCoordinate const &input_grad_shard_coord) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/device_coords, + /*value_sharding=*/ + binary_merge_disjoint_maps( + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrsShardingInfo{ + input_grad_shard_coord, + one_to_many_filter_keys( + input_grad_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == input_grad_shard_coord; + }), + }, + }, + }, + map_from_pairs(transform(output_grad_task_shards, mk_output_grad_binding))), + }; + }; - CHECK(result.size() == correct.size()); - CHECK(result_json == correct_json); - CHECK(result == correct); + std::unordered_set correct = { + mk_invocation_shard(nonempty_set{mc1, mc2}, {mc1, mc2}, pt1), + mk_invocation_shard(nonempty_set{mc3, mc4}, {mc3, mc4}, pt2), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } } } } diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 57f8d5e213..7fcc59f116 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -16,6 +16,7 @@ #include "utils/containers/require_same.h" #include "utils/containers/values.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/contains_key.h" namespace FlexFlow { @@ -108,10 +109,12 @@ struct bidict { } R const &at_l(L const &l) const { + ASSERT(contains_key(this->fwd_map, l)); return fwd_map.at(l); } L const &at_r(R const &r) const { + ASSERT(contains_key(this->bwd_map, r)); return bwd_map.at(r); } From d7cce600fc2c5c7a4e4615f442e56c4de4737ee5 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 29 May 2026 03:41:35 -0700 Subject: [PATCH 19/19] Fix other test suites --- .../src/local-execution/task_execution.cc | 5 +- .../utils/containers/map_keys_and_values.h | 21 ++++ .../utils/containers/map_keys_and_values.cc | 8 ++ lib/utils/test/src/utils/bidict/bidict.cc | 4 +- .../containers/binary_merge_disjoint_maps.cc | 10 +- .../binary_merge_disjoint_unordered_maps.cc | 34 ++++++ .../containers/binary_merge_maps_with.cc | 42 +++---- .../binary_merge_maps_with_left_dominating.cc | 10 +- ...binary_merge_maps_with_right_dominating.cc | 10 +- .../binary_merge_unordered_maps_with.cc | 110 ++++++++++++++++++ ...rge_unordered_maps_with_left_dominating.cc | 31 +++++ ...ge_unordered_maps_with_right_dominating.cc | 31 +++++ .../src/utils/containers/map_from_pairs.cc | 15 ++- .../utils/containers/merge_disjoint_maps.cc | 24 ++-- .../merge_disjoint_unordered_maps.cc | 78 +++++++++++++ .../src/utils/containers/merge_maps_with.cc | 36 +++--- .../containers/merge_unordered_maps_with.cc | 100 ++++++++++++++++ 17 files changed, 491 insertions(+), 78 deletions(-) create mode 100644 lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc create mode 100644 lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc create mode 100644 lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc create mode 100644 lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc create mode 100644 lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc create mode 100644 lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc diff --git a/lib/local-execution/src/local-execution/task_execution.cc b/lib/local-execution/src/local-execution/task_execution.cc index c96c834d4a..68a4c579f8 100644 --- a/lib/local-execution/src/local-execution/task_execution.cc +++ b/lib/local-execution/src/local-execution/task_execution.cc @@ -14,6 +14,7 @@ #include "utils/optional.h" #include "utils/overload.h" #include +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -58,9 +59,9 @@ TaskArgumentAccessor make_task_argument_accessor_for_invocation( return assert_unwrap(value.accessor); }; std::unordered_map - tensor_slots_backing = binary_merge_disjoint_maps( + tensor_slots_backing = unordered_map_from_map(binary_merge_disjoint_maps( map_keys_and_values(invocation.inputs, make_param, get_accessor), - map_keys_and_values(invocation.outputs, make_param, get_accessor)); + map_keys_and_values(invocation.outputs, make_param, get_accessor))); return TaskArgumentAccessor::create( /*allocator=*/allocator, diff --git a/lib/utils/include/utils/containers/map_keys_and_values.h b/lib/utils/include/utils/containers/map_keys_and_values.h index 70b7e17103..1873421e17 100644 --- a/lib/utils/include/utils/containers/map_keys_and_values.h +++ b/lib/utils/include/utils/containers/map_keys_and_values.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -26,6 +27,26 @@ std::unordered_map map_keys_and_values( return result; } +template , + typename V2 = std::invoke_result_t> +std::map map_keys_and_values( + std::map const &m, FK const &fk, FV const &fv) { + + std::map result; + for (auto const &kv : m) { + result.insert({fk(kv.first), fv(kv.second)}); + } + + ASSERT(m.size() == result.size(), + "keys passed to map_keys must be transformed into distinct keys"); + + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/containers/map_keys_and_values.cc b/lib/utils/src/utils/containers/map_keys_and_values.cc index b3b306988e..95608a09bd 100644 --- a/lib/utils/src/utils/containers/map_keys_and_values.cc +++ b/lib/utils/src/utils/containers/map_keys_and_values.cc @@ -1,5 +1,6 @@ #include "utils/containers/map_keys_and_values.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { @@ -13,4 +14,11 @@ using FV = std::function; template std::unordered_map map_keys_and_values( std::unordered_map const &, FK const &, FV const &); +using OK = ordered_value_type<0>; +using OK2 = ordered_value_type<1>; +using OFK = std::function; + +template std::map map_keys_and_values( + std::map const &, OFK const &, FV const &); + } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index f15f15b0fe..ead45fe86a 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -63,14 +63,14 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("bidict::erase_l") { dict.erase_l(1); CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); + CHECK_THROWS(dict.at_l(1)); CHECK(dict.at_r("two") == 2); } SUBCASE("bidict::erase_r") { dict.erase_r("one"); CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); + CHECK_THROWS(dict.at_r("one")); CHECK(dict.at_l(2) == "two"); } diff --git a/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc b/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc index bcc7b4149f..d4487343f2 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc @@ -1,27 +1,27 @@ #include "utils/containers/binary_merge_disjoint_maps.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("binary_merge_disjoint_maps") { - std::unordered_map l_map = { + std::map l_map = { {1, "one"}, {2, "two"}, }; - std::unordered_map r_map = { + std::map r_map = { {3, "three"}, }; - std::unordered_map correct = { + std::map correct = { {1, "one"}, {2, "two"}, {3, "three"}, }; SUBCASE("maps are disjoint") { - std::unordered_map result = + std::map result = binary_merge_disjoint_maps(l_map, r_map); CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..250d1c7f69 --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc @@ -0,0 +1,34 @@ +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_disjoint_unordered_maps") { + std::unordered_map l_map = { + {1, "one"}, + {2, "two"}, + }; + + std::unordered_map r_map = { + {3, "three"}, + }; + + std::unordered_map correct = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + SUBCASE("maps are disjoint") { + std::unordered_map result = + binary_merge_disjoint_unordered_maps(l_map, r_map); + + CHECK(result == correct); + } + + SUBCASE("maps are not disjoint") { + CHECK_THROWS(binary_merge_disjoint_unordered_maps(l_map, l_map)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc b/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc index 55b9c428bf..6a848565e1 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc @@ -1,5 +1,5 @@ #include "utils/containers/binary_merge_maps_with.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include #include @@ -11,47 +11,47 @@ TEST_SUITE(FF_TEST_SUITE) { std::string const &) -> std::string { PANIC(); }; SUBCASE("lhs and rhs do not overlap") { - std::unordered_map lhs = { + std::map lhs = { {1, "lhs_one."}, {4, "lhs_four."}, }; - std::unordered_map rhs = { + std::map rhs = { {2, "rhs_two."}, {5, "rhs_five."}, }; - std::unordered_map correct = { + std::map correct = { {1, "lhs_one."}, {2, "rhs_two."}, {4, "lhs_four."}, {5, "rhs_five."}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); CHECK(result == correct); } SUBCASE("lhs and rhs overlap") { - std::unordered_map lhs = { + std::map lhs = { {1, "lhs_one."}, {4, "lhs_four."}, }; - std::unordered_map rhs = { + std::map rhs = { {2, "rhs_two."}, {4, "rhs_four."}, {5, "rhs_five."}, }; - std::unordered_map result = binary_merge_maps_with( + std::map result = binary_merge_maps_with( lhs, rhs, [](std::string const &l, std::string const &r) { return l + r; }); - std::unordered_map correct = { + std::map correct = { {1, "lhs_one."}, {2, "rhs_two."}, {4, "lhs_four.rhs_four."}, @@ -62,47 +62,47 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("lhs is empty") { - std::unordered_map lhs = {}; + std::map lhs = {}; - std::unordered_map rhs = { + std::map rhs = { {2, "rhs_two."}, {4, "rhs_four."}, {5, "rhs_five."}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); - std::unordered_map correct = rhs; + std::map correct = rhs; CHECK(result == correct); } SUBCASE("rhs is empty") { - std::unordered_map lhs = { + std::map lhs = { {1, "lhs_one."}, {4, "lhs_four."}, }; - std::unordered_map rhs = {}; + std::map rhs = {}; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); - std::unordered_map correct = lhs; + std::map correct = lhs; CHECK(result == correct); } SUBCASE("both lhs and rhs are empty") { - std::unordered_map lhs = {}; + std::map lhs = {}; - std::unordered_map rhs = {}; + std::map rhs = {}; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); - std::unordered_map correct = {}; + std::map correct = {}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc index 27a389d400..fe152dd832 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc @@ -1,5 +1,5 @@ #include "utils/containers/binary_merge_maps_with_left_dominating.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include #include @@ -7,23 +7,23 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("binary_merge_maps_with_left_dominating") { - std::unordered_map l_map = { + std::map l_map = { {1, "one"}, {2, "left_two"}, }; - std::unordered_map r_map = { + std::map r_map = { {2, "right_two"}, {3, "three"}, }; - std::unordered_map correct = { + std::map correct = { {1, "one"}, {2, "left_two"}, {3, "three"}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with_left_dominating(l_map, r_map); CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc index 153266989e..c107f2b7ff 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc @@ -1,5 +1,5 @@ #include "utils/containers/binary_merge_maps_with_right_dominating.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include #include @@ -7,23 +7,23 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("binary_merge_maps_with_right_dominating") { - std::unordered_map l_map = { + std::map l_map = { {1, "one"}, {2, "left_two"}, }; - std::unordered_map r_map = { + std::map r_map = { {2, "right_two"}, {3, "three"}, }; - std::unordered_map correct = { + std::map correct = { {1, "one"}, {2, "right_two"}, {3, "three"}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with_right_dominating(l_map, r_map); CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc new file mode 100644 index 0000000000..4e825b99e2 --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc @@ -0,0 +1,110 @@ +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_unordered_maps_with") { + auto fail_if_called = [](std::string const &, + std::string const &) -> std::string { PANIC(); }; + + SUBCASE("lhs and rhs do not overlap") { + std::unordered_map lhs = { + {1, "lhs_one."}, + {4, "lhs_four."}, + }; + + std::unordered_map rhs = { + {2, "rhs_two."}, + {5, "rhs_five."}, + }; + + std::unordered_map correct = { + {1, "lhs_one."}, + {2, "rhs_two."}, + {4, "lhs_four."}, + {5, "rhs_five."}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + CHECK(result == correct); + } + + SUBCASE("lhs and rhs overlap") { + std::unordered_map lhs = { + {1, "lhs_one."}, + {4, "lhs_four."}, + }; + + std::unordered_map rhs = { + {2, "rhs_two."}, + {4, "rhs_four."}, + {5, "rhs_five."}, + }; + + std::unordered_map result = binary_merge_unordered_maps_with( + lhs, rhs, [](std::string const &l, std::string const &r) { + return l + r; + }); + + std::unordered_map correct = { + {1, "lhs_one."}, + {2, "rhs_two."}, + {4, "lhs_four.rhs_four."}, + {5, "rhs_five."}, + }; + + CHECK(result == correct); + } + + SUBCASE("lhs is empty") { + std::unordered_map lhs = {}; + + std::unordered_map rhs = { + {2, "rhs_two."}, + {4, "rhs_four."}, + {5, "rhs_five."}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + std::unordered_map correct = rhs; + + CHECK(result == correct); + } + + SUBCASE("rhs is empty") { + std::unordered_map lhs = { + {1, "lhs_one."}, + {4, "lhs_four."}, + }; + + std::unordered_map rhs = {}; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + std::unordered_map correct = lhs; + + CHECK(result == correct); + } + + SUBCASE("both lhs and rhs are empty") { + std::unordered_map lhs = {}; + + std::unordered_map rhs = {}; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + std::unordered_map correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc new file mode 100644 index 0000000000..d857cf2d91 --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc @@ -0,0 +1,31 @@ +#include "utils/containers/binary_merge_unordered_maps_with_left_dominating.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_unordered_maps_with_left_dominating") { + std::unordered_map l_map = { + {1, "one"}, + {2, "left_two"}, + }; + + std::unordered_map r_map = { + {2, "right_two"}, + {3, "three"}, + }; + + std::unordered_map correct = { + {1, "one"}, + {2, "left_two"}, + {3, "three"}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with_left_dominating(l_map, r_map); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc new file mode 100644 index 0000000000..71f50c4dac --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc @@ -0,0 +1,31 @@ +#include "utils/containers/binary_merge_unordered_maps_with_right_dominating.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_unordered_maps_with_right_dominating") { + std::unordered_map l_map = { + {1, "one"}, + {2, "left_two"}, + }; + + std::unordered_map r_map = { + {2, "right_two"}, + {3, "three"}, + }; + + std::unordered_map correct = { + {1, "one"}, + {2, "right_two"}, + {3, "three"}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with_right_dominating(l_map, r_map); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/map_from_pairs.cc b/lib/utils/test/src/utils/containers/map_from_pairs.cc index fc387f1d1a..48e8b9fe05 100644 --- a/lib/utils/test/src/utils/containers/map_from_pairs.cc +++ b/lib/utils/test/src/utils/containers/map_from_pairs.cc @@ -1,24 +1,23 @@ #include "utils/containers/map_from_pairs.h" -#include "test/utils/doctest/fmt/unordered_map.h" -#include "utils/hash/pair.h" +#include "test/utils/doctest/fmt/map.h" #include #include +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("map_from_pairs") { - std::unordered_set> input = - std::unordered_set>{ + std::set> input = + std::set>{ {1, "one"}, {2, "two"}, }; - std::unordered_map result = map_from_pairs(input); - - std::unordered_map correct = - std::unordered_map{ + std::map result = map_from_pairs(input); + std::map correct = + std::map{ {1, "one"}, {2, "two"}, }; diff --git a/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc b/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc index 24e8d548ae..bf4b2202d7 100644 --- a/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc +++ b/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc @@ -1,37 +1,37 @@ #include "utils/containers/merge_disjoint_maps.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("merge_disjoint_maps") { - std::unordered_map m1 = { + std::map m1 = { {4, "four"}, {2, "two"}, }; - std::unordered_map m2 = { + std::map m2 = { {3, "four"}, }; - std::unordered_map m3 = { + std::map m3 = { {1, "one"}, }; - std::unordered_map m4 = {}; + std::map m4 = {}; SUBCASE("maps are disjoint") { - std::vector> input = { + std::vector> input = { m1, m2, m3, m4, }; - std::unordered_map result = merge_disjoint_maps(input); + std::map result = merge_disjoint_maps(input); - std::unordered_map correct = { + std::map correct = { {4, "four"}, {2, "two"}, {3, "four"}, @@ -42,12 +42,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("maps are not disjoint") { - std::unordered_map m5 = { + std::map m5 = { {4, "five"}, {6, "six"}, }; - std::vector> input = { + std::vector> input = { m1, m2, m3, @@ -59,12 +59,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("maps are not disjoint but have identical values") { - std::unordered_map m5 = { + std::map m5 = { {4, "four"}, {6, "six"}, }; - std::vector> input = { + std::vector> input = { m1, m2, m3, diff --git a/lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..ceecea96c1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc @@ -0,0 +1,78 @@ +#include "utils/containers/merge_disjoint_unordered_maps.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("merge_disjoint_unordered_maps") { + std::unordered_map m1 = { + {4, "four"}, + {2, "two"}, + }; + + std::unordered_map m2 = { + {3, "four"}, + }; + + std::unordered_map m3 = { + {1, "one"}, + }; + + std::unordered_map m4 = {}; + + SUBCASE("maps are disjoint") { + std::vector> input = { + m1, + m2, + m3, + m4, + }; + + std::unordered_map result = merge_disjoint_unordered_maps(input); + + std::unordered_map correct = { + {4, "four"}, + {2, "two"}, + {3, "four"}, + {1, "one"}, + }; + + CHECK(result == correct); + } + + SUBCASE("maps are not disjoint") { + std::unordered_map m5 = { + {4, "five"}, + {6, "six"}, + }; + + std::vector> input = { + m1, + m2, + m3, + m4, + m5, + }; + + CHECK_THROWS(merge_disjoint_unordered_maps(input)); + } + + SUBCASE("maps are not disjoint but have identical values") { + std::unordered_map m5 = { + {4, "four"}, + {6, "six"}, + }; + + std::vector> input = { + m1, + m2, + m3, + m4, + m5, + }; + + CHECK_THROWS(merge_disjoint_unordered_maps(input)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/merge_maps_with.cc b/lib/utils/test/src/utils/containers/merge_maps_with.cc index ec5b31abf3..fd73a9345e 100644 --- a/lib/utils/test/src/utils/containers/merge_maps_with.cc +++ b/lib/utils/test/src/utils/containers/merge_maps_with.cc @@ -1,5 +1,5 @@ #include "utils/containers/merge_maps_with.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include "test/utils/rapidcheck.h" #include "utils/containers/binary_merge_maps_with.h" #include @@ -14,37 +14,37 @@ TEST_SUITE(FF_TEST_SUITE) { RC_SUBCASE( "with two inputs, matches binary_merge_maps_with", - [&](std::unordered_map const &lhs, - std::unordered_map const &rhs) { - std::unordered_map from_merge_maps_with = + [&](std::map const &lhs, + std::map const &rhs) { + std::map from_merge_maps_with = merge_maps_with(std::vector{lhs, rhs}, string_concat); - std::unordered_map from_binary_merge_maps_with = + std::map from_binary_merge_maps_with = binary_merge_maps_with(lhs, rhs, string_concat); CHECK(from_merge_maps_with == from_binary_merge_maps_with); }); SUBCASE("maps overlap") { - std::unordered_map map1 = { + std::map map1 = { {1, "map1_one."}, {4, "map1_four."}, }; - std::unordered_map map2 = { + std::map map2 = { {2, "map2_two."}, {4, "map2_four."}, {5, "map2_five."}, }; - std::unordered_map map3 = { + std::map map3 = { {1, "map3_one."}, }; - std::unordered_map result = + std::map result = merge_maps_with(std::vector{map1, map2, map3}, string_concat); - std::unordered_map correct = { + std::map correct = { {1, "map1_one.map3_one."}, {2, "map2_two."}, {4, "map1_four.map2_four."}, @@ -58,24 +58,24 @@ TEST_SUITE(FF_TEST_SUITE) { std::string const &) -> std::string { PANIC(); }; SUBCASE("maps do not overlap") { - std::unordered_map map1 = { + std::map map1 = { {8, "map1_eight."}, {4, "map1_four."}, }; - std::unordered_map map2 = { + std::map map2 = { {2, "map2_two."}, {5, "map2_five."}, }; - std::unordered_map map3 = { + std::map map3 = { {1, "map3_one."}, }; - std::unordered_map result = + std::map result = merge_maps_with(std::vector{map1, map2, map3}, fail_if_called); - std::unordered_map correct = { + std::map correct = { {1, "map3_one."}, {2, "map2_two."}, {4, "map1_four."}, @@ -87,12 +87,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("no maps are provided") { - std::vector> maps = {}; + std::vector> maps = {}; - std::unordered_map result = + std::map result = merge_maps_with(maps, fail_if_called); - std::unordered_map correct = {}; + std::map correct = {}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc b/lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc new file mode 100644 index 0000000000..66827ca453 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc @@ -0,0 +1,100 @@ +#include "utils/containers/merge_unordered_maps_with.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/rapidcheck.h" +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("merge_unordered_maps_with") { + auto string_concat = [](std::string const &l, std::string const &r) { + return l + r; + }; + + RC_SUBCASE( + "with two inputs, matches binary_merge_unordered_maps_with", + [&](std::unordered_map const &lhs, + std::unordered_map const &rhs) { + std::unordered_map from_merge_unordered_maps_with = + merge_unordered_maps_with(std::vector{lhs, rhs}, string_concat); + + std::unordered_map from_binary_merge_unordered_maps_with = + binary_merge_unordered_maps_with(lhs, rhs, string_concat); + + CHECK(from_merge_unordered_maps_with == from_binary_merge_unordered_maps_with); + }); + + SUBCASE("maps overlap") { + std::unordered_map map1 = { + {1, "map1_one."}, + {4, "map1_four."}, + }; + + std::unordered_map map2 = { + {2, "map2_two."}, + {4, "map2_four."}, + {5, "map2_five."}, + }; + + std::unordered_map map3 = { + {1, "map3_one."}, + }; + + std::unordered_map result = + merge_unordered_maps_with(std::vector{map1, map2, map3}, string_concat); + + std::unordered_map correct = { + {1, "map1_one.map3_one."}, + {2, "map2_two."}, + {4, "map1_four.map2_four."}, + {5, "map2_five."}, + }; + + CHECK(result == correct); + } + + auto fail_if_called = [](std::string const &, + std::string const &) -> std::string { PANIC(); }; + + SUBCASE("maps do not overlap") { + std::unordered_map map1 = { + {8, "map1_eight."}, + {4, "map1_four."}, + }; + + std::unordered_map map2 = { + {2, "map2_two."}, + {5, "map2_five."}, + }; + + std::unordered_map map3 = { + {1, "map3_one."}, + }; + + std::unordered_map result = + merge_unordered_maps_with(std::vector{map1, map2, map3}, fail_if_called); + + std::unordered_map correct = { + {1, "map3_one."}, + {2, "map2_two."}, + {4, "map1_four."}, + {5, "map2_five."}, + {8, "map1_eight."}, + }; + + CHECK(result == correct); + } + + SUBCASE("no maps are provided") { + std::vector> maps = {}; + + std::unordered_map result = + merge_unordered_maps_with(maps, fail_if_called); + + std::unordered_map correct = {}; + + CHECK(result == correct); + } + } +}